diff --git a/README.md b/README.md index d10dd86a..04df88ea 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ ![logo](docs/assets/tf.net.logo.png) -**TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. +**TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in C# which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. [![Join the chat at https://gitter.im/publiclab/publiclab](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/sci-sharp/community) [![Tensorflow.NET](https://ci.appveyor.com/api/projects/status/wx4td43v2d3f2xj6?svg=true)](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net) @@ -16,7 +16,7 @@ TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). Install-Package TensorFlow.NET ### Install tensorflow binary ### For CPU version PM> Install-Package SciSharp.TensorFlow.Redist + ### For GPU version (CUDA and cuDNN are required) PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU ``` -Import TF.NET. - -```cs -using Tensorflow; -``` - -Add two constants: -```cs -// Create a Constant op -var a = tf.constant(4.0f); -var b = tf.constant(5.0f); -var c = tf.add(a, b); - -using (var sess = tf.Session()) -{ - var o = sess.run(c); -} -``` +Import TF.NET in your project. -Feed placeholder: ```cs -// Create a placeholder op -var a = tf.placeholder(tf.float32); -var b = tf.placeholder(tf.float32); -var c = tf.add(a, b); - -using(var sess = tf.Session()) -{ - var o = sess.run(c, new FeedItem(a, 3.0f), new FeedItem(b, 2.0f)); -} +using static Tensorflow.Binding; ``` Linear Regression: @@ -91,39 +66,40 @@ var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); var init = tf.global_variables_initializer(); // Start training -with(tf.Session(), sess => +using(tf.Session()) { // Run the initializer sess.run(init); - + // Fit all training data for (int epoch = 0; epoch < training_epochs; epoch++) { foreach (var (x, y) in zip(train_X, train_Y)) - sess.run(optimizer, new FeedItem(X, x), new FeedItem(Y, y)); - + sess.run(optimizer, (X, x), (Y, y)); + // Display logs per epoch step if ((epoch + 1) % display_step == 0) { - var c = sess.run(cost, new FeedItem(X, train_X), new FeedItem(Y, train_Y)); + var c = sess.run(cost, (X, train_X), (Y, train_Y)); Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}"); } - - Console.WriteLine("Optimization Finished!"); - var training_cost = sess.run(cost, new FeedItem(X, train_X), new FeedItem(Y, train_Y)); - Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); - - // Testing example - var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f); - var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); - Console.WriteLine("Testing... 
(Mean square loss Comparison)"); - - var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), new FeedItem(X, test_X), new FeedItem(Y, test_Y)); - Console.WriteLine($"Testing cost={testing_cost}"); - - var diff = Math.Abs((float)training_cost - (float)testing_cost); - Console.WriteLine($"Absolute mean square loss difference: {diff}"); } + + Console.WriteLine("Optimization Finished!"); + var training_cost = sess.run(cost, (X, train_X), (Y, train_Y)); + Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); + + // Testing example + var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f); + var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); + Console.WriteLine("Testing... (Mean square loss Comparison)"); + var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), + (X, test_X), (Y, test_Y)); + Console.WriteLine($"Testing cost={testing_cost}"); + var diff = Math.Abs((float)training_cost - (float)testing_cost); + Console.WriteLine($"Absolute mean square loss difference: {diff}"); + + return diff < 0.01; }); ``` diff --git a/src/TensorFlowHub/TensorFlowHub.csproj b/src/TensorFlowHub/TensorFlowHub.csproj index 97e6497e..16e22183 100644 --- a/src/TensorFlowHub/TensorFlowHub.csproj +++ b/src/TensorFlowHub/TensorFlowHub.csproj @@ -17,6 +17,6 @@ https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 - + \ No newline at end of file diff --git a/src/TensorFlowNET.Core/APIs/c_api.cs b/src/TensorFlowNET.Core/APIs/c_api.cs index 1656edd0..adf0b86f 100644 --- a/src/TensorFlowNET.Core/APIs/c_api.cs +++ b/src/TensorFlowNET.Core/APIs/c_api.cs @@ -59,6 +59,6 @@ namespace Tensorflow } [DllImport(TensorFlowLibName)] - public static unsafe extern IntPtr TF_Version(); + public static extern IntPtr TF_Version(); } } diff --git a/src/TensorFlowNET.Core/APIs/tf.graph.cs b/src/TensorFlowNET.Core/APIs/tf.graph.cs index a60c0413..cee941ed 100644 --- a/src/TensorFlowNET.Core/APIs/tf.graph.cs +++ b/src/TensorFlowNET.Core/APIs/tf.graph.cs @@ -14,11 +14,16 @@ limitations under the License. ******************************************************************************/ +using static Tensorflow.ops; + namespace Tensorflow { public partial class tensorflow { public graph_util_impl graph_util => new graph_util_impl(); + + public GraphKeys GraphKeys { get; } = new GraphKeys(); + public Graph get_default_graph() { return ops.get_default_graph(); diff --git a/src/TensorFlowNET.Core/APIs/tf.image.cs b/src/TensorFlowNET.Core/APIs/tf.image.cs new file mode 100644 index 00000000..e2e3206b --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.image.cs @@ -0,0 +1,61 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+******************************************************************************/ + +using System.Collections.Generic; +using Tensorflow.IO; + +namespace Tensorflow +{ + public partial class tensorflow + { + public image_internal image = new image_internal(); + + public class image_internal + { + public Tensor decode_jpeg(Tensor contents, + int channels = 0, + int ratio = 1, + bool fancy_upscaling = true, + bool try_recover_truncated = false, + float acceptable_fraction = 1, + string dct_method = "", + string name = null) + => gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio, + fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated, + acceptable_fraction: acceptable_fraction, dct_method: dct_method); + + public Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, string name = null) + => gen_image_ops.resize_bilinear(images, size, align_corners: align_corners, name: name); + + public Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, string name = null) + => gen_image_ops.convert_image_dtype(image, dtype, saturate: saturate, name: name); + + public Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype = TF_DataType.TF_UINT8, + string name = null, bool expand_animations = true) + => image_ops_impl.decode_image(contents, channels: channels, dtype: dtype, + name: name, expand_animations: expand_animations); + + /// + /// Convenience function to check if the 'contents' encodes a JPEG image. + /// + /// + /// + /// + public static Tensor is_jpeg(Tensor contents, string name = null) + => image_ops_impl.is_jpeg(contents, name: name); + } + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs index f6fa380c..c9294653 100644 --- a/src/TensorFlowNET.Core/APIs/tf.init.cs +++ b/src/TensorFlowNET.Core/APIs/tf.init.cs @@ -52,5 +52,13 @@ namespace Tensorflow stddev: stddev, seed: seed, dtype: dtype); + + public IInitializer random_normal_initializer(float mean = 0.0f, + float stddev = 1.0f, + int? seed = null, + TF_DataType dtype = TF_DataType.DtInvalid) => new RandomNormal(mean: mean, + stddev: stddev, + seed: seed, + dtype: dtype); } } diff --git a/src/TensorFlowNET.Core/APIs/tf.io.cs b/src/TensorFlowNET.Core/APIs/tf.io.cs index 5980cf81..394357de 100644 --- a/src/TensorFlowNET.Core/APIs/tf.io.cs +++ b/src/TensorFlowNET.Core/APIs/tf.io.cs @@ -24,8 +24,6 @@ namespace Tensorflow public GFile gfile = new GFile(); public Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name); - public gen_image_ops image => new gen_image_ops(); - public void import_graph_def(GraphDef graph_def, Dictionary input_map = null, string[] return_elements = null, diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index e5ecb908..01789617 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -148,6 +148,24 @@ namespace Tensorflow }); } + /// + /// Local Response Normalization. 
+ /// + /// + /// + /// + /// + /// + /// + /// + public Tensor lrn(Tensor input, int depth_radius = 5, int bias = 1, + int alpha = 1, float beta = 0.5f, string name = null) + => gen_nn_ops.local_response_normalization(input, depth_radius: depth_radius, bias: bias, + alpha: alpha, beta: beta, name: name); + + public Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) + => nn_ops.leaky_relu(features, alpha: alpha, name: name); + public rnn_cell_impl rnn_cell => new rnn_cell_impl(); public Tensor softmax(Tensor logits, int axis = -1, string name = null) diff --git a/src/TensorFlowNET.Core/APIs/tf.ops.cs b/src/TensorFlowNET.Core/APIs/tf.ops.cs index ba678327..ec533af4 100644 --- a/src/TensorFlowNET.Core/APIs/tf.ops.cs +++ b/src/TensorFlowNET.Core/APIs/tf.ops.cs @@ -33,5 +33,13 @@ namespace Tensorflow /// The scope name. public ops.NameScope name_scope(string name, string default_name = "", object values = null) => new ops.NameScope(name, default_name, values); + + /// + /// Does nothing. Only useful as a placeholder for control edges. + /// + /// + /// + public Tensor no_op(string name = null) + => gen_control_flow_ops.no_op(name: name); } } diff --git a/src/TensorFlowNET.Core/APIs/tf.strings.cs b/src/TensorFlowNET.Core/APIs/tf.strings.cs new file mode 100644 index 00000000..38d92803 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.strings.cs @@ -0,0 +1,32 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+******************************************************************************/ + +using System.Collections.Generic; +using Tensorflow.IO; + +namespace Tensorflow +{ + public partial class tensorflow + { + public strings_internal strings = new strings_internal(); + public class strings_internal + { + public Tensor substr(Tensor input, int pos, int len, + string name = null, string @uint = "BYTE") + => string_ops.substr(input, pos, len, name: name, @uint: @uint); + } + } +} diff --git a/src/TensorFlowNET.Core/Train/tf.optimizers.cs b/src/TensorFlowNET.Core/APIs/tf.train.cs similarity index 94% rename from src/TensorFlowNET.Core/Train/tf.optimizers.cs rename to src/TensorFlowNET.Core/APIs/tf.train.cs index 0c801f90..a943308b 100644 --- a/src/TensorFlowNET.Core/Train/tf.optimizers.cs +++ b/src/TensorFlowNET.Core/APIs/tf.train.cs @@ -31,6 +31,9 @@ namespace Tensorflow public Optimizer AdamOptimizer(float learning_rate, string name = "Adam") => new AdamOptimizer(learning_rate, name: name); + public ExponentialMovingAverage ExponentialMovingAverage(float decay) + => new ExponentialMovingAverage(decay); + public Saver Saver(VariableV1[] var_list = null) => new Saver(var_list: var_list); public string write_graph(Graph graph, string logdir, string name, bool as_text = true) diff --git a/src/TensorFlowNET.Core/APIs/tf.variable.cs b/src/TensorFlowNET.Core/APIs/tf.variable.cs index e855535a..b3c5bf43 100644 --- a/src/TensorFlowNET.Core/APIs/tf.variable.cs +++ b/src/TensorFlowNET.Core/APIs/tf.variable.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System.Collections.Generic; +using static Tensorflow.Binding; namespace Tensorflow { @@ -22,7 +23,7 @@ namespace Tensorflow { public VariableV1[] global_variables(string scope = null) { - return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) as List) + return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) as List) .ToArray(); } @@ -32,6 +33,14 @@ namespace Tensorflow return variables.variables_initializer(g.ToArray()); } + /// + /// Returns all variables created with `trainable=True`. 
+ /// + /// + /// + public VariableV1[] trainable_variables(string scope = null) + => (variables.trainable_variables() as List).ToArray(); + public RefVariable get_variable(string name, TensorShape shape = null, TF_DataType dtype = TF_DataType.DtInvalid, diff --git a/src/TensorFlowNET.Core/Assembly/Properties.cs b/src/TensorFlowNET.Core/Assembly/Properties.cs new file mode 100644 index 00000000..28aee65e --- /dev/null +++ b/src/TensorFlowNET.Core/Assembly/Properties.cs @@ -0,0 +1,4 @@ +using System.Runtime.CompilerServices; +#if DEBUG +[assembly: InternalsVisibleTo("TensorFlowNET.UnitTest, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")] +#endif diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index da2cdf6e..d723283f 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -178,13 +178,18 @@ namespace Tensorflow public static IEnumerable<(TKey, TValue)> enumerate(KeyValuePair[] values) { - foreach (var item in values) + var len = values.Length; + for (var i = 0; i < len; i++) + { + var item = values[i]; yield return (item.Key, item.Value); + } } public static IEnumerable<(int, T)> enumerate(IList values) { - for (int i = 0; i < values.Count; i++) + var len = values.Count; + for (int i = 0; i < len; i++) yield return (i, values[i]); } @@ -308,15 +313,14 @@ namespace Tensorflow public static IEnumerable TupleToEnumerable(object tuple) { Type t = tuple.GetType(); - if(t.IsGenericType && (t.FullName.StartsWith("System.Tuple") || t.FullName.StartsWith("System.ValueTuple"))) + if (t.IsGenericType && (t.FullName.StartsWith("System.Tuple") || t.FullName.StartsWith("System.ValueTuple"))) { var flds = t.GetFields(); - for(int i = 0; i < flds.Length;i++) + for (int i = 0; i < flds.Length; i++) { yield return flds[i].GetValue(tuple); } - } - else + } else { throw new System.Exception("Expected Tuple."); } @@ -329,12 +333,9 @@ namespace Tensorflow public static bool isinstance(object Item1, object tuple) { - var tup = TupleToEnumerable(tuple); - foreach(var t in tup) - { - if(isinstance(Item1, (Type)t)) + foreach (var t in TupleToEnumerable(tuple)) + if (isinstance(Item1, (Type) t)) return true; - } return false; } } diff --git a/src/TensorFlowNET.Core/Buffers/Buffer.cs b/src/TensorFlowNET.Core/Buffers/Buffer.cs index dbe576b8..c08d3175 100644 --- a/src/TensorFlowNET.Core/Buffers/Buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/Buffer.cs @@ -15,58 +15,116 @@ ******************************************************************************/ using System; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; +using NumSharp.Backends.Unmanaged; +using static Tensorflow.c_api; namespace Tensorflow { + /// + /// Represents a TF_Buffer that can be passed to Tensorflow. + /// public class Buffer : DisposableObject { - private TF_Buffer buffer => Marshal.PtrToStructure(_handle); + private unsafe TF_Buffer buffer + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => *bufferptr; + } + + private unsafe TF_Buffer* bufferptr + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => (TF_Buffer*) _handle; + } - public byte[] Data + /// + /// The memory block representing this buffer. 
+ /// + /// The deallocator is set to null. + public UnmanagedMemoryBlock MemoryBlock { - get + get { - var data = new byte[buffer.length]; - if (data.Length > 0) - Marshal.Copy(buffer.data, data, 0, data.Length); - return data; + unsafe + { + EnsureNotDisposed(); + var buff = (TF_Buffer*) _handle; + return new UnmanagedMemoryBlock((byte*) buff->data.ToPointer(), (long) buff->length); + } } } - public int Length => (int)buffer.length; - - public Buffer() + /// + /// The bytes length of this buffer. + /// + public ulong Length { - _handle = c_api.TF_NewBuffer(); + get + { + EnsureNotDisposed(); + return buffer.length; + } } - public Buffer(IntPtr handle) + public Buffer() => _handle = TF_NewBuffer(); + + internal Buffer(IntPtr handle) { + if (handle == IntPtr.Zero) + throw new ArgumentException("Handle (IntPtr) can't be zero.", nameof(handle)); + _handle = handle; } - public Buffer(byte[] data) - { - var dst = Marshal.AllocHGlobal(data.Length); - Marshal.Copy(data, 0, dst, data.Length); + public Buffer(byte[] data) : this(_toBuffer(data)) + { } - _handle = c_api.TF_NewBufferFromString(dst, (ulong)data.Length); + private static IntPtr _toBuffer(byte[] data) + { + if (data == null) + throw new ArgumentNullException(nameof(data)); - Marshal.FreeHGlobal(dst); + unsafe + { + fixed (byte* src = data) + return TF_NewBufferFromString(new IntPtr(src), (ulong) data.LongLength); + } } public static implicit operator IntPtr(Buffer buffer) { + buffer.EnsureNotDisposed(); return buffer._handle; } - public static implicit operator byte[](Buffer buffer) + public static explicit operator byte[](Buffer buffer) => buffer.ToArray(); //has to be explicit, developer will assume it doesn't cost. + + /// + /// Copies this buffer's contents onto a array. + /// + public byte[] ToArray() { - return buffer.Data; + EnsureNotDisposed(); + + unsafe + { + var len = buffer.length; + if (len == 0) + return Array.Empty(); + + byte[] data = new byte[len]; + fixed (byte* dst = data) + System.Buffer.MemoryCopy((void*) bufferptr->data, dst, len, len); + + return data; + } } - protected override void DisposeUnManagedState(IntPtr handle) - => c_api.TF_DeleteBuffer(handle); + protected override void DisposeUnmanagedResources(IntPtr handle) + { + TF_DeleteBuffer(handle); + } } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/DisposableObject.cs b/src/TensorFlowNET.Core/DisposableObject.cs index 7e416e6d..a7fc5a2c 100644 --- a/src/TensorFlowNET.Core/DisposableObject.cs +++ b/src/TensorFlowNET.Core/DisposableObject.cs @@ -16,6 +16,8 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Text; namespace Tensorflow @@ -26,52 +28,71 @@ namespace Tensorflow public abstract class DisposableObject : IDisposable { protected IntPtr _handle; + protected bool _disposed; - protected DisposableObject() { } + [SuppressMessage("ReSharper", "UnusedMember.Global")] + protected DisposableObject() + { } - public DisposableObject(IntPtr handle) - { - _handle = handle; - } + protected DisposableObject(IntPtr handle) + => _handle = handle; - protected virtual void DisposeManagedState() + [SuppressMessage("ReSharper", "InvertIf")] + private void internal_dispose(bool disposing) { - } + if (_disposed) + return; - protected abstract void DisposeUnManagedState(IntPtr handle); + _disposed = true; - protected virtual void Dispose(bool disposing) - { + //first handle managed, they might use the unmanaged resources. 
if (disposing) - { - // free unmanaged resources (unmanaged objects) and override a finalizer below. - if (_handle != IntPtr.Zero) - { - // dispose managed state (managed objects). - DisposeManagedState(); - - // set large fields to null. - DisposeUnManagedState(_handle); + // dispose managed state (managed objects). + DisposeManagedResources(); - _handle = IntPtr.Zero; - } + //free unmanaged memory + if (_handle != IntPtr.Zero) + { + DisposeUnmanagedResources(_handle); + _handle = IntPtr.Zero; } } - // override a finalizer only if Dispose(bool disposing) above has code to free unmanaged resources. + /// + /// Dispose any managed resources. + /// + /// Equivalent to what you would perform inside + protected virtual void DisposeManagedResources() + { } + + /// + /// Dispose any unmanaged resources related to given . + /// + protected abstract void DisposeUnmanagedResources(IntPtr handle); + ~DisposableObject() { - // Do not change this code. Put cleanup code in Dispose(bool disposing) above. - Dispose(false); + internal_dispose(false); } - // This code added to correctly implement the disposable pattern. public void Dispose() { - // Do not change this code. Put cleanup code in Dispose(bool disposing) above. - Dispose(true); - // uncomment the following line if the finalizer is overridden above. - GC.SuppressFinalize(this); + lock(this) + { + internal_dispose(true); + GC.SuppressFinalize(this); + } + } + + /// + /// If is then throws + /// + /// When is + [MethodImpl(MethodImplOptions.AggressiveInlining)] + protected void EnsureNotDisposed() + { + if (_disposed) + throw new ObjectDisposedException($"Unable to access disposed object, Type: {GetType().Name}"); } } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Eager/Context.cs b/src/TensorFlowNET.Core/Eager/Context.cs index 4ee43d35..700e1236 100644 --- a/src/TensorFlowNET.Core/Eager/Context.cs +++ b/src/TensorFlowNET.Core/Eager/Context.cs @@ -2,12 +2,10 @@ namespace Tensorflow.Eager { - public class Context : IDisposable + public class Context : DisposableObject { - private IntPtr _handle; - - public static int GRAPH_MODE = 0; - public static int EAGER_MODE = 1; + public const int GRAPH_MODE = 0; + public const int EAGER_MODE = 1; public int default_execution_mode; @@ -17,19 +15,16 @@ namespace Tensorflow.Eager status.Check(true); } - public void Dispose() - { - c_api.TFE_DeleteContext(_handle); - } + /// + /// Dispose any unmanaged resources related to given . + /// + protected sealed override void DisposeUnmanagedResources(IntPtr handle) + => c_api.TFE_DeleteContext(_handle); - public bool executing_eagerly() - { - return false; - } - public static implicit operator IntPtr(Context ctx) - { - return ctx._handle; - } + public bool executing_eagerly() => false; + + public static implicit operator IntPtr(Context ctx) + => ctx._handle; } } diff --git a/src/TensorFlowNET.Core/Eager/ContextOptions.cs b/src/TensorFlowNET.Core/Eager/ContextOptions.cs index 4bffddf6..12c4cdfc 100644 --- a/src/TensorFlowNET.Core/Eager/ContextOptions.cs +++ b/src/TensorFlowNET.Core/Eager/ContextOptions.cs @@ -1,24 +1,22 @@ using System; +using System.IO; namespace Tensorflow.Eager { - public class ContextOptions : IDisposable + public class ContextOptions : DisposableObject { - private IntPtr _handle; + public ContextOptions() : base(c_api.TFE_NewContextOptions()) + { } - public ContextOptions() - { - _handle = c_api.TFE_NewContextOptions(); - } + /// + /// Dispose any unmanaged resources related to given . 
+ /// + protected sealed override void DisposeUnmanagedResources(IntPtr handle) + => c_api.TFE_DeleteContextOptions(_handle); - public void Dispose() - { - c_api.TFE_DeleteContextOptions(_handle); - } - public static implicit operator IntPtr(ContextOptions opts) - { - return opts._handle; - } + public static implicit operator IntPtr(ContextOptions opts) + => opts._handle; } + } diff --git a/src/TensorFlowNET.Core/Exceptions/KeyError.cs b/src/TensorFlowNET.Core/Exceptions/KeyError.cs index 8cecae76..949fd309 100644 --- a/src/TensorFlowNET.Core/Exceptions/KeyError.cs +++ b/src/TensorFlowNET.Core/Exceptions/KeyError.cs @@ -2,7 +2,7 @@ namespace Tensorflow { - public class KeyError : Exception + public class KeyError : TensorflowException { public KeyError() : base() { diff --git a/src/TensorFlowNET.Core/Exceptions/RuntimeError.cs b/src/TensorFlowNET.Core/Exceptions/RuntimeError.cs index 09a02a4a..6f7e4f48 100644 --- a/src/TensorFlowNET.Core/Exceptions/RuntimeError.cs +++ b/src/TensorFlowNET.Core/Exceptions/RuntimeError.cs @@ -2,7 +2,7 @@ namespace Tensorflow { - public class RuntimeError : Exception + public class RuntimeError : TensorflowException { public RuntimeError() : base() { diff --git a/src/TensorFlowNET.Core/Exceptions/TensorflowException.cs b/src/TensorFlowNET.Core/Exceptions/TensorflowException.cs new file mode 100644 index 00000000..ee9eca69 --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/TensorflowException.cs @@ -0,0 +1,36 @@ +using System; +using System.Runtime.Serialization; + +namespace Tensorflow +{ + + /// + /// Serves as a base class to all exceptions of Tensorflow.NET. + /// + [Serializable] + public class TensorflowException : Exception + { + /// Initializes a new instance of the class. + public TensorflowException() + { } + + /// Initializes a new instance of the class with serialized data. + /// The that holds the serialized object data about the exception being thrown. + /// The that contains contextual information about the source or destination. + /// The info parameter is null. + /// The class name is null or is zero (0). + protected TensorflowException(SerializationInfo info, StreamingContext context) : base(info, context) + { } + + /// Initializes a new instance of the class with a specified error message. + /// The message that describes the error. + public TensorflowException(string message) : base(message) + { } + + /// Initializes a new instance of the class with a specified error message and a reference to the inner exception that is the cause of this exception. + /// The error message that explains the reason for the exception. + /// The exception that is the cause of the current exception, or a null reference (Nothing in Visual Basic) if no inner exception is specified. 
+ public TensorflowException(string message, Exception innerException) : base(message, innerException) + { } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Exceptions/TypeError.cs b/src/TensorFlowNET.Core/Exceptions/TypeError.cs index a4c37988..42c8e3a0 100644 --- a/src/TensorFlowNET.Core/Exceptions/TypeError.cs +++ b/src/TensorFlowNET.Core/Exceptions/TypeError.cs @@ -2,7 +2,7 @@ namespace Tensorflow { - public class TypeError : Exception + public class TypeError : TensorflowException { public TypeError() : base() { diff --git a/src/TensorFlowNET.Core/Exceptions/ValueError.cs b/src/TensorFlowNET.Core/Exceptions/ValueError.cs index 825d27a1..0d6fb4e3 100644 --- a/src/TensorFlowNET.Core/Exceptions/ValueError.cs +++ b/src/TensorFlowNET.Core/Exceptions/ValueError.cs @@ -2,7 +2,7 @@ namespace Tensorflow { - public class ValueError : Exception + public class ValueError : TensorflowException { public ValueError() : base() { diff --git a/src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefOptions.cs b/src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefOptions.cs index dc3955b2..145a3058 100644 --- a/src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefOptions.cs +++ b/src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefOptions.cs @@ -6,10 +6,5 @@ { } - - ~ScopedTFImportGraphDefOptions() - { - base.Dispose(); - } } } diff --git a/src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefResults.cs b/src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefResults.cs index fde7d2bb..dc1236e3 100644 --- a/src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefResults.cs +++ b/src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefResults.cs @@ -13,10 +13,5 @@ namespace Tensorflow.Framework.Models { } - - ~ScopedTFImportGraphDefResults() - { - base.Dispose(); - } } } diff --git a/src/TensorFlowNET.Core/Framework/Models/ScopedTFStatus.cs b/src/TensorFlowNET.Core/Framework/Models/ScopedTFStatus.cs index 068cfbee..a427c994 100644 --- a/src/TensorFlowNET.Core/Framework/Models/ScopedTFStatus.cs +++ b/src/TensorFlowNET.Core/Framework/Models/ScopedTFStatus.cs @@ -5,10 +5,5 @@ public ScopedTFStatus() : base() { } - - ~ScopedTFStatus() - { - base.Dispose(); - } } } diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs index db9dba38..d7d7ef7e 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs +++ b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs @@ -95,7 +95,7 @@ namespace Tensorflow break; case KindOneofCase.BytesList: //var proto_type = ops.get_collection_proto_type(key) - if (ops.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key)) + if (tf.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key)) { foreach (var value in col.Value.BytesList.Value) { @@ -146,7 +146,7 @@ namespace Tensorflow } } - var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, + var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope: scope_to_prepend_to_names); var var_list = new Dictionary(); variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); @@ -180,7 +180,7 @@ namespace Tensorflow var graph = ops.get_default_graph(); var var_list = new Dictionary(); - var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) as List; + var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) as List; if (variables != null) { diff --git a/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs 
b/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs index 9f9b4ad7..8a2bc5c3 100644 --- a/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs +++ b/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs @@ -15,6 +15,8 @@ ******************************************************************************/ using System.Collections.Generic; +using System.IO; +using Tensorflow.Util; namespace Tensorflow { @@ -27,12 +29,12 @@ namespace Tensorflow if(_registered_ops == null) { _registered_ops = new Dictionary(); - var handle = c_api.TF_GetAllOpList(); - var buffer = new Buffer(handle); - var op_list = OpList.Parser.ParseFrom(buffer); - - foreach (var op_def in op_list.Op) - _registered_ops[op_def.Name] = op_def; + using (var buffer = new Buffer(c_api.TF_GetAllOpList())) + { + var op_list = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + foreach (var op_def in op_list.Op) + _registered_ops[op_def.Name] = op_def; + } } return _registered_ops; diff --git a/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs b/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs index 6c9f6b18..66419b3e 100644 --- a/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs +++ b/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs @@ -14,49 +14,62 @@ limitations under the License. ******************************************************************************/ +using System; using System.Collections.Generic; using System.Linq; using static Tensorflow.Binding; namespace Tensorflow { - public class DefaultGraphStack + + /// + /// Serves as a stack for determining current default graph. + /// + public class DefaultGraphStack { - List stack = new List(); + private readonly List _stack = new List(); public void set_controller(Graph @default) { - if (!stack.Exists(x => x.Graph == @default)) - stack.Add(new StackModel { Graph = @default, IsDefault = true }); + if (!_stack.Exists(x => x.Graph == @default)) + _stack.Add(new StackModel {Graph = @default, IsDefault = true}); - foreach (var s in stack) + foreach (var s in _stack) s.IsDefault = s.Graph == @default; } public Graph get_controller() { - if (stack.Count(x => x.IsDefault) == 0) - stack.Add(new StackModel { Graph = tf.Graph(), IsDefault = true }); + if (_stack.Count(x => x.IsDefault) == 0) + _stack.Add(new StackModel {Graph = tf.Graph(), IsDefault = true}); + for (var i = _stack.Count - 1; i >= 0; i--) + { + var x = _stack[i]; + if (x.IsDefault) + return x.Graph; + } - return stack.Last(x => x.IsDefault).Graph; + throw new TensorflowException("Unable to find a default graph"); } public bool remove(Graph g) { - var sm = stack.FirstOrDefault(x => x.Graph == g); - if (sm == null) return false; - return stack.Remove(sm); + if (_stack.Count == 0) + return false; + + var sm = _stack.Find(model => model.Graph == g); + return sm != null && _stack.Remove(sm); } public void reset() { - stack.Clear(); + _stack.Clear(); } - } - public class StackModel - { - public Graph Graph { get; set; } - public bool IsDefault { get; set; } + private class StackModel + { + public Graph Graph { get; set; } + public bool IsDefault { get; set; } + } } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs index 4a3ac793..c97e1b6f 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; 
using System.Linq; using Tensorflow.Operations; @@ -66,8 +67,9 @@ namespace Tensorflow /// within the context should have control dependencies on /// `control_inputs`. /// + [SuppressMessage("ReSharper", "CoVariantArrayConversion")] public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) - => control_dependencies(control_inputs == null ? null : control_inputs.OfType().ToArray()); + => control_dependencies((object[])control_inputs); /// /// Returns a context manager that specifies control dependencies. diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs index 17828c73..4a7e0ed8 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs @@ -14,6 +14,9 @@ limitations under the License. ******************************************************************************/ +using System.IO; +using Tensorflow.Util; + namespace Tensorflow { public partial class Graph @@ -23,21 +26,19 @@ namespace Tensorflow var buffer = new Buffer(); c_api.TF_GraphToGraphDef(_handle, buffer, s); s.Check(true); - // var def = GraphDef.Parser.ParseFrom(buffer); - // buffer.Dispose(); return buffer; } private GraphDef _as_graph_def(bool add_shapes = false) { - var status = new Status(); - var buffer = ToGraphDef(status); - status.Check(true); - status.Dispose(); - - var def = GraphDef.Parser.ParseFrom(buffer); - buffer.Dispose(); + GraphDef def; + using (var status = new Status()) + using (var buffer = ToGraphDef(status)) + { + status.Check(true); + def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + } // Strip the experimental library field iff it's empty. // if(def.Library.Function.Count == 0) @@ -45,7 +46,7 @@ namespace Tensorflow return def; } - public GraphDef as_graph_def(bool add_shapes = false) + public GraphDef as_graph_def(bool add_shapes = false) => _as_graph_def(add_shapes); } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs index 82695527..0b2dc0e6 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs @@ -30,11 +30,10 @@ namespace Tensorflow var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); - for (int i = 0; i < num_return_outputs; i++) - { - var handle = return_output_handle + i * size; - return_outputs[i] = Marshal.PtrToStructure(handle); - } + + var tf_output_ptr = (TF_Output*) return_output_handle; + for (int i = 0; i < num_return_outputs; i++) + return_outputs[i] = *(tf_output_ptr + i); Marshal.FreeHGlobal(return_output_handle); diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs index 436afcc9..0e28dd9a 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs @@ -18,6 +18,7 @@ using System; using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; +using Tensorflow.Util; using static Tensorflow.Binding; namespace Tensorflow @@ -30,7 +31,7 @@ namespace Tensorflow using (var status = new Status()) { c_api.TF_GraphGetOpDef(_handle, type, buffer, status); - return OpDef.Parser.ParseFrom(buffer.Data); + return OpDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); } } @@ -39,16 +40,20 @@ namespace Tensorflow return 
c_api.TF_NewOperation(_handle, opType, opName); } - public unsafe Operation[] ReturnOperations(IntPtr results) + public Operation[] ReturnOperations(IntPtr results) { TF_Operation return_oper_handle = new TF_Operation(); int num_return_opers = 0; c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); Operation[] return_opers = new Operation[num_return_opers]; + var tf_op_size = Marshal.SizeOf(); for (int i = 0; i < num_return_opers; i++) { - var handle = return_oper_handle.node + Marshal.SizeOf() * i; - return_opers[i] = new Operation(*(IntPtr*)handle); + unsafe + { + var handle = return_oper_handle.node + tf_op_size * i; + return_opers[i] = new Operation(*(IntPtr*)handle); + } } return return_opers; @@ -67,7 +72,7 @@ namespace Tensorflow public ITensorOrOperation[] get_operations() { - return _nodes_by_name.Values.Select(x => x).ToArray(); + return _nodes_by_name.Values.ToArray(); } /// @@ -81,7 +86,7 @@ namespace Tensorflow public ITensorOrOperation _get_operation_by_name_unsafe(string name) { - return _nodes_by_name.ContainsKey(name) ? _nodes_by_name[name] : null; + return _nodes_by_name.TryGetValue(name, out var val) ? val : null; } public ITensorOrOperation _get_operation_by_tf_operation(IntPtr tf_oper) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 77926dca..0dfb68db 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -23,57 +23,58 @@ using static Tensorflow.Binding; namespace Tensorflow { + /* + A TensorFlow computation, represented as a dataflow graph. + + A `Graph` contains a set of + `tf.Operation` objects, + which represent units of computation; and + `tf.Tensor` objects, which represent + the units of data that flow between operations. + + A default `Graph` is always registered, and accessible by calling + `tf.get_default_graph`. + To add an operation to the default graph, simply call one of the functions + that defines a new `Operation`: + + ```python + c = tf.constant(4.0) + assert c.graph is tf.get_default_graph() + ``` + + Another typical usage involves the + `tf.Graph.as_default` + context manager, which overrides the current default graph for the + lifetime of the context: + + ```python + g = tf.Graph() + with g.as_default(): + # Define operations and tensors in `g`. + c = tf.constant(30.0) + assert c.graph is g + ``` + + Important note: This class *is not* thread-safe for graph construction. All + operations should be created from a single thread, or external + synchronization must be provided. Unless otherwise specified, all methods + are not thread-safe. + + A `Graph` instance supports an arbitrary number of "collections" + that are identified by name. For convenience when building a large + graph, collections can store groups of related objects: for + example, the `tf.Variable` uses a collection (named + `tf.GraphKeys.GLOBAL_VARIABLES`) for + all variables that are created during the construction of a graph. The caller + may define additional collections by specifying a new name. + */ + /// - /// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations. - /// This leads to a low-level programming model in which you first define the dataflow graph, - /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. 
- /// https://www.tensorflow.org/guide/graphs + /// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations. + /// This leads to a low-level programming model in which you first define the dataflow graph, + /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. /// - /* - A TensorFlow computation, represented as a dataflow graph. - - A `Graph` contains a set of - `tf.Operation` objects, - which represent units of computation; and - `tf.Tensor` objects, which represent - the units of data that flow between operations. - - A default `Graph` is always registered, and accessible by calling - `tf.get_default_graph`. - To add an operation to the default graph, simply call one of the functions - that defines a new `Operation`: - - ```python - c = tf.constant(4.0) - assert c.graph is tf.get_default_graph() - ``` - - Another typical usage involves the - `tf.Graph.as_default` - context manager, which overrides the current default graph for the - lifetime of the context: - - ```python - g = tf.Graph() - with g.as_default(): - # Define operations and tensors in `g`. - c = tf.constant(30.0) - assert c.graph is g - ``` - - Important note: This class *is not* thread-safe for graph construction. All - operations should be created from a single thread, or external - synchronization must be provided. Unless otherwise specified, all methods - are not thread-safe. - - A `Graph` instance supports an arbitrary number of "collections" - that are identified by name. For convenience when building a large - graph, collections can store groups of related objects: for - example, the `tf.Variable` uses a collection (named - `tf.GraphKeys.GLOBAL_VARIABLES`) for - all variables that are created during the construction of a graph. The caller - may define additional collections by specifying a new name. - */ + /// https://www.tensorflow.org/guide/graphs
+    /// https://www.tensorflow.org/api_docs/python/tf/Graph
public partial class Graph : DisposableObject, IEnumerable { private Dictionary _nodes_by_id; @@ -368,7 +369,7 @@ namespace Tensorflow var name_key = name.ToLower(); int i = 0; if (_names_in_use.ContainsKey(name_key)) - i = _names_in_use[name_key]; + i = _names_in_use[name_key]; // Increment the number for "name_key". if (mark_as_used) _names_in_use[name_key] = i + 1; @@ -398,13 +399,13 @@ namespace Tensorflow int num_return_outputs = 0; c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle); TF_Output[] return_outputs = new TF_Output[num_return_outputs]; - for (int i = 0; i < num_return_outputs; i++) + unsafe { - var handle = return_output_handle + (Marshal.SizeOf() * i); - return_outputs[i] = Marshal.PtrToStructure(handle); + var tf_output_ptr = (TF_Output*) return_output_handle; + for (int i = 0; i < num_return_outputs; i++) + return_outputs[i] = *(tf_output_ptr + i); + return return_outputs; } - - return return_outputs; } public string[] get_all_collection_keys() @@ -439,12 +440,12 @@ namespace Tensorflow _unfetchable_ops.Add(op); } - protected override void DisposeManagedState() + protected override void DisposeManagedResources() { ops.default_graph_stack.remove(this); } - protected override void DisposeUnManagedState(IntPtr handle) + protected override void DisposeUnmanagedResources(IntPtr handle) { c_api.TF_DeleteGraph(handle); } @@ -496,11 +497,9 @@ namespace Tensorflow IEnumerator IEnumerable.GetEnumerator() => GetEnumerable().GetEnumerator(); - IEnumerator IEnumerable.GetEnumerator() - { - throw new NotImplementedException(); - } - + IEnumerator IEnumerable.GetEnumerator() + => throw new NotImplementedException(); + public static implicit operator IntPtr(Graph graph) { return graph._handle; diff --git a/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs b/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs index 97720206..70802597 100644 --- a/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs +++ b/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs @@ -20,7 +20,8 @@ namespace Tensorflow { public class ImportGraphDefOptions : DisposableObject { - public int NumReturnOutputs => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); + public int NumReturnOutputs + => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); public ImportGraphDefOptions() { @@ -37,7 +38,7 @@ namespace Tensorflow c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); } - protected override void DisposeUnManagedState(IntPtr handle) + protected override void DisposeUnmanagedResources(IntPtr handle) => c_api.TF_DeleteImportGraphDefOptions(handle); public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle; diff --git a/src/TensorFlowNET.Core/IO/gfile.cs b/src/TensorFlowNET.Core/IO/gfile.cs index 930dd652..a7303bf6 100644 --- a/src/TensorFlowNET.Core/IO/gfile.cs +++ b/src/TensorFlowNET.Core/IO/gfile.cs @@ -16,6 +16,7 @@ using System.Collections.Generic; using System.IO; +using System.Linq; namespace Tensorflow.IO { @@ -28,6 +29,9 @@ namespace Tensorflow.IO /// Traverse in order if True, post order if False. 
public IEnumerable<(string, string[], string[])> Walk(string top, bool in_order = true) { + if (!Directory.Exists(top)) + return Enumerable.Empty<(string, string[], string[])>(); + return walk_v2(top, in_order); } diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index 304e7f7b..a3ae3356 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -81,7 +81,7 @@ namespace Tensorflow.Layers // Update global default collections. - _add_elements_to_collection(_updates.ToArray(), new string[] { ops.GraphKeys.UPDATE_OPS }); + _add_elements_to_collection(_updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS }); return outputs; } diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs index b1503567..aa314efb 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs @@ -152,9 +152,9 @@ namespace Tensorflow.Operations public (T, Tensor) BuildCondBranch(Func fn) { // Add the subgraph defined by fn() to the graph. - var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); + var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); var original_result = fn(); - var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); + var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); //TODO: port this chunck of missing code: /* @@ -191,9 +191,9 @@ namespace Tensorflow.Operations public (T[], Tensor[]) BuildCondBranch(Func fn) { // Add the subgraph defined by fn() to the graph. - var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); + var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); var original_result = fn(); - var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); + var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); switch (original_result) { diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index 2c05e36a..2a76c52c 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -141,7 +141,7 @@ namespace Tensorflow.Operations data, frame_name, is_constant, parallel_iterations, name: name); if (use_input_shape) - result.SetShape(data.TensorShape); + result.set_shape(data.TensorShape); return result; } diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs index ccd88480..1faaa647 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs @@ -195,7 +195,7 @@ namespace Tensorflow.Operations // their associated TensorArrays for calling the body. 
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); var body_result = body(packed_vars_for_body[0]); - var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); + var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); // Store body_result to keep track of TensorArrays returned by body var original_body_result = new[] { body_result }; diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs new file mode 100644 index 00000000..f553d45b --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs @@ -0,0 +1,59 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations.Initializers +{ + public class RandomNormal : IInitializer + { + private float mean; + private float stddev; + private int? seed; + private TF_DataType dtype; + + public RandomNormal(float mean = 0.0f, + float stddev = 1.0f, + int? 
seed = null, + TF_DataType dtype = TF_DataType.TF_FLOAT) + { + this.mean = mean; + this.stddev = stddev; + this.seed = seed; + this.dtype = dtype; + } + + public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) + { + if (dtype == TF_DataType.DtInvalid) + dtype = this.dtype; + return random_ops.random_normal(shape, mean, stddev, dtype, seed: seed); + } + + public object get_config() + { + return new + { + mean, + stddev, + seed, + dtype + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Losses/Util.cs b/src/TensorFlowNET.Core/Operations/Losses/Util.cs index 71b3ed62..fde5bcb0 100644 --- a/src/TensorFlowNET.Core/Operations/Losses/Util.cs +++ b/src/TensorFlowNET.Core/Operations/Losses/Util.cs @@ -2,7 +2,7 @@ { public class Util { - public static void add_loss(Tensor loss, string loss_collection = ops.GraphKeys.LOSSES) + public static void add_loss(Tensor loss, string loss_collection = "losses") { if (!string.IsNullOrEmpty(loss_collection)) ops.add_to_collection(loss_collection, loss); diff --git a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs index de4bf964..1f4ce2d8 100644 --- a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs +++ b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs @@ -22,7 +22,7 @@ namespace Tensorflow public class LossesImpl { public Tensor compute_weighted_loss(Tensor losses, Tensor weights = null, string scope = null, - string loss_collection = ops.GraphKeys.LOSSES, string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) + string loss_collection = "losses", string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) { return tf_with(ops.name_scope(scope, default_name: "weighted_loss", (losses, weights)), delegate { @@ -101,7 +101,7 @@ namespace Tensorflow Tensor logits, float weights = 1.0f, string scope = null, - string loss_collection= ops.GraphKeys.LOSSES, + string loss_collection= "losses", string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) { return tf_with(ops.name_scope(scope, diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index 6c31eb62..49d504ab 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -181,6 +181,31 @@ namespace Tensorflow.Operations return _op.outputs; } + /// + /// Local Response Normalization. 
+ /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor local_response_normalization(Tensor input, int depth_radius = 5, int bias = 1, + int alpha = 1, float beta = 0.5f, string name = null) + { + var _op = _op_def_lib._apply_op_helper("LRN", name: name, args: new + { + input, + depth_radius, + bias, + alpha, + beta + }); + + return _op.output; + } + public static Tensor log_softmax(Tensor logits, string name = null) { var _op = _op_def_lib._apply_op_helper("LogSoftmax", name: name, args: new @@ -189,6 +214,17 @@ namespace Tensorflow.Operations }); return _op.outputs[0]; + } + + public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) + { + var _op = _op_def_lib._apply_op_helper("LeakyRelu", name: name, args: new + { + features, + alpha + }); + + return _op.output; } public static Tensor max_pool(Tensor input, diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index 3198942b..1b68d1cd 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -233,7 +233,7 @@ namespace Tensorflow.Operations dims.AddRange(x_static_shape.dims.Skip(2)); var shape = new TensorShape(dims.ToArray()); - x_t.SetShape(shape); + x_t.set_shape(shape); return x_t; } diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs index 24348322..62c8f378 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs @@ -50,14 +50,12 @@ namespace Tensorflow public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) { - int size = Marshal.SizeOf(); - var handle = Marshal.AllocHGlobal(size); + var handle = Marshal.AllocHGlobal(Marshal.SizeOf()); int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); var consumers = new TF_Input[num]; + var inputptr = (TF_Input*) handle; for (int i = 0; i < num; i++) - { - consumers[i] = Marshal.PtrToStructure(handle + i * size); - } + consumers[i] = *(inputptr + i); return consumers; } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 059290f4..5fff9ade 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -17,7 +17,9 @@ using Google.Protobuf.Collections; using System; using System.Collections.Generic; +using System.IO; using System.Linq; +using Tensorflow.Util; namespace Tensorflow { @@ -226,9 +228,12 @@ namespace Tensorflow using (var status = new Status()) using (var buf = new Buffer()) { - c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); - status.Check(true); - x = AttrValue.Parser.ParseFrom(buf); + unsafe + { + c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); + status.Check(true); + x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); + } } string oneof_value = x.ValueCase.ToString(); @@ -259,7 +264,7 @@ namespace Tensorflow { c_api.TF_OperationToNodeDef(_handle, buffer, s); s.Check(); - return NodeDef.Parser.ParseFrom(buffer); + return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); } } @@ -299,8 +304,7 @@ namespace Tensorflow ///
public TF_Output _tf_output(int output_idx) { - var tf_output = new TF_Output(op, output_idx); - return tf_output; + return new TF_Output(op, output_idx); } /// @@ -308,8 +312,7 @@ namespace Tensorflow /// public TF_Input _tf_input(int input_idx) { - var tf_input = new TF_Input(op, input_idx); - return tf_input; + return new TF_Input(op, input_idx); } } } diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index d3213250..92f65906 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -260,8 +260,7 @@ namespace Tensorflow return tf_with(ops.name_scope(name, "ones", new { dims }), scope => { name = scope; - var shape = ops.convert_to_tensor(dims, dtype: TF_DataType.TF_INT32); - var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name); + var output = _constant_if_small(1, dims, dtype, name); return output; }); } @@ -351,7 +350,7 @@ namespace Tensorflow var input_shape = tensor_util.to_shape(input_tensor.shape); if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) { - var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_datatype()); + var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype()); return constant_op.constant(nd, name: name); } } diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index 04595256..04ef54a7 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -431,8 +431,8 @@ namespace Tensorflow merges = _convert_flows_to_tensorarrays(new Tensor[] { (Tensor)orig_res_t }, merges); - ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t); - ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f); + ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); + ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); return merges[0]; }); @@ -479,8 +479,8 @@ namespace Tensorflow merges = _convert_flows_to_tensorarrays(orig_res_t, merges); - ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t); - ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f); + ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); + ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); return merges; }); @@ -596,7 +596,7 @@ namespace Tensorflow swap_memory: swap_memory); if (loop_context.outer_context == null) - ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context); + ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context); var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, return_same_structure); diff --git a/src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs index baaf4fbe..90893815 100644 --- a/src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs @@ -23,7 +23,7 @@ namespace Tensorflow { public static OpDefLibrary _op_def_lib = new OpDefLibrary(); - public Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, string name= null) + public static Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, string name= null) { if (dtype == image.dtype) return array_ops.identity(image, name: name); @@ -57,7 +57,7 @@ namespace Tensorflow }); } - public Tensor decode_jpeg(Tensor contents, + 
public static Tensor decode_jpeg(Tensor contents, int channels = 0, int ratio = 1, bool fancy_upscaling = true, @@ -88,7 +88,70 @@ namespace Tensorflow } } - public Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, string name = null) + public static Tensor decode_gif(Tensor contents, + string name = null) + { + // Add nodes to the TensorFlow graph. + if (tf.context.executing_eagerly()) + { + throw new NotImplementedException("decode_gif"); + } + else + { + var _op = _op_def_lib._apply_op_helper("DecodeGif", name: name, args: new + { + contents + }); + + return _op.output; + } + } + + public static Tensor decode_png(Tensor contents, + int channels = 0, + TF_DataType dtype = TF_DataType.TF_UINT8, + string name = null) + { + // Add nodes to the TensorFlow graph. + if (tf.context.executing_eagerly()) + { + throw new NotImplementedException("decode_png"); + } + else + { + var _op = _op_def_lib._apply_op_helper("DecodePng", name: name, args: new + { + contents, + channels, + dtype + }); + + return _op.output; + } + } + + public static Tensor decode_bmp(Tensor contents, + int channels = 0, + string name = null) + { + // Add nodes to the TensorFlow graph. + if (tf.context.executing_eagerly()) + { + throw new NotImplementedException("decode_bmp"); + } + else + { + var _op = _op_def_lib._apply_op_helper("DecodeBmp", name: name, args: new + { + contents, + channels + }); + + return _op.output; + } + } + + public static Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, string name = null) { if (tf.context.executing_eagerly()) { diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 5c3bcf72..c1257e19 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -141,7 +141,7 @@ namespace Tensorflow { var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); - return _op.outputs[0]; + return _op.output; } public static Tensor atan(Tensor x, string name = null) diff --git a/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs index 8fe9ba71..06ae70a3 100644 --- a/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs @@ -40,7 +40,7 @@ namespace Tensorflow name: name, args: new { shape, dtype, seed, seed2 }); - return _op.outputs[0]; + return _op.output; } /// diff --git a/src/TensorFlowNET.Core/Operations/gen_string_ops.cs b/src/TensorFlowNET.Core/Operations/gen_string_ops.cs new file mode 100644 index 00000000..87ac589e --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_string_ops.cs @@ -0,0 +1,42 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
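A hypothetical call site for the new `decode_png`, `decode_bmp` and `decode_gif` bindings (not part of this patch; `contents` is assumed to be a scalar string tensor holding the encoded bytes):

```cs
// Illustrative sketch: contents is a scalar string tensor with the encoded file bytes.
var contents = tf.placeholder(TF_DataType.TF_STRING);
var png = gen_image_ops.decode_png(contents, channels: 3);   // uint8 image by default
var bmp = gen_image_ops.decode_bmp(contents);
var gif = gen_image_ops.decode_gif(contents);
```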
+******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class gen_string_ops + { + static readonly OpDefLibrary _op_def_lib; + static gen_string_ops() { _op_def_lib = new OpDefLibrary(); } + + public static Tensor substr(Tensor input, int pos, int len, + string name = null, string @uint = "BYTE") + { + var _op = _op_def_lib._apply_op_helper("Substr", name: name, args: new + { + input, + pos, + len, + unit = @uint + }); + + return _op.output; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs new file mode 100644 index 00000000..65ed8eb1 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs @@ -0,0 +1,132 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class image_ops_impl + { + public static Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype = TF_DataType.TF_UINT8, + string name = null, bool expand_animations = true) + { + Tensor substr = null; + + Func _jpeg = () => + { + int jpeg_channels = channels; + var good_channels = math_ops.not_equal(jpeg_channels, 4, name: "check_jpeg_channels"); + string channels_msg = "Channels must be in (None, 0, 1, 3) when decoding JPEG 'images'"; + var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg }); + return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate + { + return convert_image_dtype(gen_image_ops.decode_jpeg(contents, channels), dtype); + }); + }; + + Func _gif = () => + { + int gif_channels = channels; + var good_channels = math_ops.logical_and( + math_ops.not_equal(gif_channels, 1, name: "check_gif_channels"), + math_ops.not_equal(gif_channels, 4, name: "check_gif_channels")); + + string channels_msg = "Channels must be in (None, 0, 3) when decoding GIF images"; + var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg }); + return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate + { + var result = convert_image_dtype(gen_image_ops.decode_gif(contents), dtype); + if (!expand_animations) + // result = array_ops.gather(result, 0); + throw new NotImplementedException(""); + return result; + }); + }; + + Func _bmp = () => + { + int bmp_channels = channels; + var signature = string_ops.substr(contents, 0, 2); + var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp"); + string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP"; + var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg }); + var 
good_channels = math_ops.not_equal(bmp_channels, 1, name: "check_channels"); + string channels_msg = "Channels must be in (None, 0, 3) when decoding BMP images"; + var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg }); + return tf_with(ops.control_dependencies(new[] { assert_decode, assert_channels }), delegate + { + return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype); + }); + }; + + Func _png = () => + { + return convert_image_dtype(gen_image_ops.decode_png( + contents, + channels, + dtype: dtype), + dtype); + }; + + Func check_gif = () => + { + var is_gif = math_ops.equal(substr, "\x47\x49\x46", name: "is_gif"); + return control_flow_ops.cond(is_gif, _gif, _bmp, name: "cond_gif"); + }; + + Func check_png = () => + { + return control_flow_ops.cond(_is_png(contents), _png, check_gif, name: "cond_png"); + }; + + return tf_with(ops.name_scope(name, "decode_image"), scope => + { + substr = string_ops.substr(contents, 0, 3); + return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg"); + }); + } + + public static Tensor is_jpeg(Tensor contents, string name = null) + { + return tf_with(ops.name_scope(name, "is_jpeg"), scope => + { + var substr = string_ops.substr(contents, 0, 3); + return math_ops.equal(substr, "\xff\xd8\xff", name: name); + }); + } + + public static Tensor _is_png(Tensor contents, string name = null) + { + return tf_with(ops.name_scope(name, "is_png"), scope => + { + var substr = string_ops.substr(contents, 0, 3); + return math_ops.equal(substr, @"\211PN", name: name); + }); + } + + public static Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, + string name = null) + { + if (dtype == image.dtype) + return array_ops.identity(image, name: name); + + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index fa1fda12..f5cfdb37 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -168,6 +168,9 @@ namespace Tensorflow public static Tensor multiply(Tx x, Ty y, string name = null) => gen_math_ops.mul(x, y, name: name); + public static Tensor not_equal(Tx x, Ty y, string name = null) + => gen_math_ops.not_equal(x, y, name: name); + public static Tensor mul_no_nan(Tx x, Ty y, string name = null) => gen_math_ops.mul_no_nan(x, y, name: name); @@ -264,6 +267,9 @@ namespace Tensorflow return gen_math_ops.log(x, name); } + public static Tensor logical_and(Tensor x, Tensor y, string name = null) + => gen_math_ops.logical_and(x, y, name: name); + public static Tensor lgamma(Tensor x, string name = null) => gen_math_ops.lgamma(x, name: name); diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs index cbf55861..b189bb83 100644 --- a/src/TensorFlowNET.Core/Operations/nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs @@ -98,7 +98,7 @@ namespace Tensorflow // float to be selected, hence we use a >= comparison. 
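A sketch of how `image_ops_impl.decode_image` might be called once the dispatch above is in place (illustrative only; the tensor names are hypothetical):

```cs
// Illustrative sketch: decode_image dispatches on the magic bytes at graph-run time.
var contents = tf.placeholder(TF_DataType.TF_STRING);
var image = image_ops_impl.decode_image(contents, channels: 3);
var looks_like_jpeg = image_ops_impl.is_jpeg(contents);       // bool scalar tensor
```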
var keep_mask = random_tensor >= rate; var ret = x * scale * math_ops.cast(keep_mask, x.dtype); - ret.SetShape(x.TensorShape); + ret.set_shape(x.TensorShape); return ret; }); } @@ -116,6 +116,19 @@ namespace Tensorflow return _softmax(logits, gen_nn_ops.log_softmax, axis, name); } + public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) + { + return tf_with(ops.name_scope(name, "LeakyRelu", new { features, alpha }), scope => + { + name = scope; + features = ops.convert_to_tensor(features, name: "features"); + if (features.dtype.is_integer()) + features = math_ops.cast(features, dtypes.float32); + return gen_nn_ops.leaky_relu(features, alpha: alpha, name: name); + //return math_ops.maximum(alpha * features, features, name: name); + }); + } + /// /// Performs the max pooling on the input. /// diff --git a/src/TensorFlowNET.Core/Operations/random_ops.py.cs b/src/TensorFlowNET.Core/Operations/random_ops.py.cs index 9ca4db88..02e522bf 100644 --- a/src/TensorFlowNET.Core/Operations/random_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/random_ops.py.cs @@ -39,9 +39,10 @@ namespace Tensorflow { return tf_with(ops.name_scope(name, "random_normal", new { shape, mean, stddev }), scope => { + name = scope; var shape_tensor = _ShapeTensor(shape); var mean_tensor = ops.convert_to_tensor(mean, dtype: dtype, name: "mean"); - var stddev_tensor = ops.convert_to_tensor(stddev, dtype: dtype, name = "stddev"); + var stddev_tensor = ops.convert_to_tensor(stddev, dtype: dtype, name: "stddev"); var (seed1, seed2) = random_seed.get_seed(seed); var rnd = gen_random_ops.random_standard_normal(shape_tensor, dtype: dtype, seed: seed1, seed2: seed2); var mul = rnd * stddev_tensor; diff --git a/src/TensorFlowNET.Core/Operations/string_ops.cs b/src/TensorFlowNET.Core/Operations/string_ops.cs new file mode 100644 index 00000000..ee46cf78 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/string_ops.cs @@ -0,0 +1,38 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class string_ops + { + /// + /// Return substrings from `Tensor` of strings. + /// + /// + /// + /// + /// + /// + /// + public static Tensor substr(Tensor input, int pos, int len, + string name = null, string @uint = "BYTE") + => gen_string_ops.substr(input, pos, len, name: name, @uint: @uint); + } +} diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index efe2afd4..58177df2 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -1,408 +1,413 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. 
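A minimal sketch of the new `string_ops.substr` wrapper (illustrative; note that the `@uint` argument is forwarded to the op as its `unit` attribute):

```cs
// Illustrative sketch: grab the first bytes of each string, e.g. to check a file signature.
var contents = tf.placeholder(TF_DataType.TF_STRING);
var signature = string_ops.substr(contents, 0, 3);   // @uint: "BYTE" is the default unit
```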
- - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using NumSharp; -using System; -using System.Collections; -using System.Collections.Generic; -using System.Linq; -using System.Numerics; -using System.Text; - -namespace Tensorflow -{ - public class BaseSession : DisposableObject - { - protected Graph _graph; - protected bool _opened; - protected bool _closed; - protected int _current_version; - protected byte[] _target; - public Graph graph => _graph; - - public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) - { - _graph = g is null ? ops.get_default_graph() : g; - _graph.as_default(); - _target = UTF8Encoding.UTF8.GetBytes(target); - - SessionOptions newOpts = null; - if (opts == null) - newOpts = new SessionOptions(); - - var status = new Status(); - - _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); - - // dispose newOpts - if (opts == null) - c_api.TF_DeleteSessionOptions(newOpts); - - status.Check(true); - } - - public virtual void run(Operation op, params FeedItem[] feed_dict) - { - _run(op, feed_dict); - } - - public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict) - { - return _run(fetche, feed_dict)[0]; - } - - public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) - { - var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); - return (results[0], results[1], results[2], results[3]); - } - - public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) - { - var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); - return (results[0], results[1], results[2]); - } - - public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) - { - var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); - return (results[0], results[1]); - } - - public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict) - { - return _run(fetches, feed_dict); - } - - public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) - { - var feed_items = feed_dict == null ? new FeedItem[0] : - feed_dict.Keys.OfType().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); - return _run(fetches, feed_items); - } - - private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) - { - var feed_dict_tensor = new Dictionary(); - var feed_map = new Dictionary(); - - Func> feed_fn = (item) => - { - return new (object, object)[] { (item.Key, item.Value) }; - }; - - // Validate and process feed_dict. 
- if (feed_dict != null) - { - foreach (var feed in feed_dict) - { - foreach (var (subfeed, subfeed_val) in feed_fn(feed)) - { - var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); - //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used - feed_dict_tensor[subfeed_t] = subfeed_val; - feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); - } - } - } - - // Create a fetch handler to take care of the structure of fetches. - var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); - - // Run request and get response. - // We need to keep the returned movers alive for the following _do_run(). - // These movers are no longer needed when _do_run() completes, and - // are deleted when `movers` goes out of scope when this _run() ends. - var _ = _update_with_movers(); - var final_fetches = fetch_handler.fetches(); - var final_targets = fetch_handler.targets(); - - // We only want to really perform the run if fetches or targets are provided, - // or if the call is a partial run that specifies feeds. - var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); - - return fetch_handler.build_results(this, results); - } - - /// - /// Runs a step based on the given fetches and feeds. - /// - /// - /// A list of operations to be run, but not fetched. - /// - /// - /// - /// A list of numpy ndarrays, corresponding to the elements of - /// `fetch_list`. If the ith element of `fetch_list` contains the - /// name of an operation, the first Tensor output of that operation - /// will be returned for that element. - /// - private NDArray[] _do_run(List target_list, List fetch_list, Dictionary feed_dict) - { - var feeds = feed_dict.Select(x => - { - if (x.Key is Tensor tensor) - { - switch (x.Value) - { -#if _REGEN - %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] - %foreach types% - case #1 v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case #1[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - % -#else - case sbyte v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case sbyte[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case byte v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case byte[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case short v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case short[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case ushort v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case ushort[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case int v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case int[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case uint v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case uint[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case long v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case long[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case ulong v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case ulong[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case float v: - return new KeyValuePair(tensor._as_tf_output(), new 
Tensor(v)); - case float[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case double v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case double[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case Complex v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case Complex[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); -#endif - case bool v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL)); - case string v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case IntPtr v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case Tensor v: - return new KeyValuePair(tensor._as_tf_output(), v); - case NDArray v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); - default: - throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "")}"); - } - } - throw new NotImplementedException("_do_run.feed_dict"); - }).ToArray(); - var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); - var targets = target_list; - - return _call_tf_sessionrun(feeds, fetches, target_list); - } - - private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair[] feed_dict, TF_Output[] fetch_list, List target_list) - { - // Ensure any changes to the graph are reflected in the runtime. - _extend_graph(); - - var status = new Status(); - - var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); - - c_api.TF_SessionRun(_handle, - run_options: null, - inputs: feed_dict.Select(f => f.Key).ToArray(), - input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), - ninputs: feed_dict.Length, - outputs: fetch_list, - output_values: output_values, - noutputs: fetch_list.Length, - target_opers: target_list.Select(f => (IntPtr)f).ToArray(), - ntargets: target_list.Count, - run_metadata: IntPtr.Zero, - status: status); - - status.Check(true); - - var result = new NDArray[fetch_list.Length]; - - for (int i = 0; i < fetch_list.Length; i++) - result[i] = fetchValue(output_values[i]); - - for (int i = 0; i < feed_dict.Length; i++) - feed_dict[i].Value.Dispose(); - - return result; - } - - private unsafe NDArray fetchValue(IntPtr output) - { - var tensor = new Tensor(output); - NDArray nd = null; - Type type = tensor.dtype.as_numpy_datatype(); - var ndims = tensor.shape; - var offset = c_api.TF_TensorData(output); - - if(ndims.Length == 0) - { - switch (tensor.dtype) - { - case TF_DataType.TF_BOOL: - nd = NDArray.Scalar(*(bool*)offset); - break; - case TF_DataType.TF_STRING: - var bytes = tensor.Data(); - // wired, don't know why we have to start from offset 9. 
- // length in the begin - var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); - nd = NDArray.FromString(str); - break; - case TF_DataType.TF_UINT8: - nd = NDArray.Scalar(*(byte*)offset); - break; - case TF_DataType.TF_INT16: - nd = NDArray.Scalar(*(short*)offset); - break; - case TF_DataType.TF_INT32: - nd = NDArray.Scalar(*(int*)offset); - break; - case TF_DataType.TF_INT64: - nd = NDArray.Scalar(*(long*)offset); - break; - case TF_DataType.TF_FLOAT: - nd = NDArray.Scalar(*(float*)offset); - break; - case TF_DataType.TF_DOUBLE: - nd = NDArray.Scalar(*(double*)offset); - break; - default: - throw new NotImplementedException("can't fetch output"); - } - } - else - { - switch (tensor.dtype) - { - case TF_DataType.TF_BOOL: - var bools = new bool[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(bools).reshape(ndims); - break; - case TF_DataType.TF_STRING: - var bytes = tensor.Data(); - // wired, don't know why we have to start from offset 9. - // length in the begin - var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); - nd = np.array(str); - break; - case TF_DataType.TF_UINT8: - var _bytes = new byte[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - _bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(_bytes).reshape(ndims); - break; - case TF_DataType.TF_INT16: - var shorts = new short[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(shorts).reshape(ndims); - break; - case TF_DataType.TF_INT32: - var ints = new int[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - ints[i] = *(int*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(ints).reshape(ndims); - break; - case TF_DataType.TF_INT64: - var longs = new long[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - longs[i] = *(long*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(longs).reshape(ndims); - break; - case TF_DataType.TF_FLOAT: - var floats = new float[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - floats[i] = *(float*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(floats).reshape(ndims); - break; - case TF_DataType.TF_DOUBLE: - var doubles = new double[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(doubles).reshape(ndims); - break; - default: - throw new NotImplementedException("can't fetch output"); - } - } - - tensor.Dispose(); - - return nd; - } - - /// - /// If a tensor handle that is fed to a device incompatible placeholder, - /// we move the tensor to the right device, generate a new tensor handle, - /// and update feed_dict to use the new handle. - /// - private List _update_with_movers() - { - return new List { }; - } - - private void _extend_graph() - { - - } - - public void close() - { - Dispose(); - } - - protected override void DisposeUnManagedState(IntPtr handle) - { - using (var status = new Status()) - { - c_api.TF_DeleteSession(handle, status); - status.Check(true); - } - } - } -} +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using NumSharp; +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Numerics; +using System.Text; + +namespace Tensorflow +{ + public class BaseSession : DisposableObject + { + protected Graph _graph; + protected bool _opened; + protected bool _closed; + protected int _current_version; + protected byte[] _target; + public Graph graph => _graph; + + public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) + { + _graph = g is null ? ops.get_default_graph() : g; + _graph.as_default(); + _target = UTF8Encoding.UTF8.GetBytes(target); + + SessionOptions newOpts = null; + if (opts == null) + newOpts = new SessionOptions(); + + var status = new Status(); + + _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); + + // dispose newOpts + if (opts == null) + newOpts.Dispose(); + + status.Check(true); + } + + public virtual void run(Operation op, params FeedItem[] feed_dict) + { + _run(op, feed_dict); + } + + public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict) + { + return _run(fetche, feed_dict)[0]; + } + + public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict) + { + return _run(fetche, feed_dict)[0]; + } + + public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) + { + var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); + return (results[0], results[1], results[2], results[3]); + } + + public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) + { + var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); + return (results[0], results[1], results[2]); + } + + public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) + { + var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); + return (results[0], results[1]); + } + + public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict) + { + return _run(fetches, feed_dict); + } + + public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) + { + var feed_items = feed_dict == null ? new FeedItem[0] : + feed_dict.Keys.OfType().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); + return _run(fetches, feed_items); + } + + private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) + { + var feed_dict_tensor = new Dictionary(); + var feed_map = new Dictionary(); + + Func> feed_fn = (item) => + { + return new (object, object)[] { (item.Key, item.Value) }; + }; + + // Validate and process feed_dict. 
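A sketch of the tuple-based `run` overloads declared above, combined with the `(key, value)` feed syntax (illustrative only; `loss`, `accuracy`, `X`, `Y` and the batch arrays are hypothetical):

```cs
// Illustrative sketch: fetch two tensors in one call; feeds convert implicitly to FeedItem.
using (var sess = tf.Session())
{
    var (loss_val, acc_val) = sess.run((loss, accuracy), (X, x_batch), (Y, y_batch));
    Console.WriteLine($"loss={loss_val} acc={acc_val}");
}
```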
+ if (feed_dict != null) + { + foreach (var feed in feed_dict) + { + foreach (var (subfeed, subfeed_val) in feed_fn(feed)) + { + var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); + //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used + feed_dict_tensor[subfeed_t] = subfeed_val; + feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); + } + } + } + + // Create a fetch handler to take care of the structure of fetches. + var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); + + // Run request and get response. + // We need to keep the returned movers alive for the following _do_run(). + // These movers are no longer needed when _do_run() completes, and + // are deleted when `movers` goes out of scope when this _run() ends. + var _ = _update_with_movers(); + var final_fetches = fetch_handler.fetches(); + var final_targets = fetch_handler.targets(); + + // We only want to really perform the run if fetches or targets are provided, + // or if the call is a partial run that specifies feeds. + var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); + + return fetch_handler.build_results(this, results); + } + + /// + /// Runs a step based on the given fetches and feeds. + /// + /// + /// A list of operations to be run, but not fetched. + /// + /// + /// + /// A list of numpy ndarrays, corresponding to the elements of + /// `fetch_list`. If the ith element of `fetch_list` contains the + /// name of an operation, the first Tensor output of that operation + /// will be returned for that element. + /// + private NDArray[] _do_run(List target_list, List fetch_list, Dictionary feed_dict) + { + var feeds = feed_dict.Select(x => + { + if (x.Key is Tensor tensor) + { + switch (x.Value) + { +#if _REGEN + %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] + %foreach types% + case #1 v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case #1[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + % +#else + case sbyte v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case sbyte[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case byte v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case byte[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case short v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case short[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case ushort v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case ushort[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case int v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case int[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case uint v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case uint[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case long v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case long[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case ulong v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case ulong[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case float v: + return new KeyValuePair(tensor._as_tf_output(), new 
Tensor(v)); + case float[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case double v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case double[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case Complex v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case Complex[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); +#endif + case bool v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL)); + case string v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case IntPtr v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case Tensor v: + return new KeyValuePair(tensor._as_tf_output(), v); + case NDArray v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); + default: + throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "")}"); + } + } + throw new NotImplementedException("_do_run.feed_dict"); + }).ToArray(); + var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); + var targets = target_list; + + return _call_tf_sessionrun(feeds, fetches, target_list); + } + + private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair[] feed_dict, TF_Output[] fetch_list, List target_list) + { + // Ensure any changes to the graph are reflected in the runtime. + _extend_graph(); + + var status = new Status(); + + var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); + + c_api.TF_SessionRun(_handle, + run_options: null, + inputs: feed_dict.Select(f => f.Key).ToArray(), + input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), + ninputs: feed_dict.Length, + outputs: fetch_list, + output_values: output_values, + noutputs: fetch_list.Length, + target_opers: target_list.Select(f => (IntPtr)f).ToArray(), + ntargets: target_list.Count, + run_metadata: IntPtr.Zero, + status: status); + + status.Check(true); + + var result = new NDArray[fetch_list.Length]; + + for (int i = 0; i < fetch_list.Length; i++) + result[i] = fetchValue(output_values[i]); + + for (int i = 0; i < feed_dict.Length; i++) + feed_dict[i].Value.Dispose(); + + return result; + } + + private unsafe NDArray fetchValue(IntPtr output) + { + var tensor = new Tensor(output); + NDArray nd = null; + Type type = tensor.dtype.as_numpy_dtype(); + var ndims = tensor.shape; + var offset = c_api.TF_TensorData(output); + + if(ndims.Length == 0) + { + switch (tensor.dtype) + { + case TF_DataType.TF_BOOL: + nd = NDArray.Scalar(*(bool*)offset); + break; + case TF_DataType.TF_STRING: + var bytes = tensor.BufferToArray(); + // wired, don't know why we have to start from offset 9. 
+ // length in the begin + var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); + nd = NDArray.FromString(str); + break; + case TF_DataType.TF_UINT8: + nd = NDArray.Scalar(*(byte*)offset); + break; + case TF_DataType.TF_INT16: + nd = NDArray.Scalar(*(short*)offset); + break; + case TF_DataType.TF_INT32: + nd = NDArray.Scalar(*(int*)offset); + break; + case TF_DataType.TF_INT64: + nd = NDArray.Scalar(*(long*)offset); + break; + case TF_DataType.TF_FLOAT: + nd = NDArray.Scalar(*(float*)offset); + break; + case TF_DataType.TF_DOUBLE: + nd = NDArray.Scalar(*(double*)offset); + break; + default: + throw new NotImplementedException("can't fetch output"); + } + } + else + { + switch (tensor.dtype) + { + case TF_DataType.TF_BOOL: + var bools = new bool[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(bools).reshape(ndims); + break; + case TF_DataType.TF_STRING: + var bytes = tensor.BufferToArray(); + // wired, don't know why we have to start from offset 9. + // length in the begin + var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); + nd = np.array(str); + break; + case TF_DataType.TF_UINT8: + var _bytes = new byte[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + _bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(_bytes).reshape(ndims); + break; + case TF_DataType.TF_INT16: + var shorts = new short[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(shorts).reshape(ndims); + break; + case TF_DataType.TF_INT32: + var ints = new int[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + ints[i] = *(int*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(ints).reshape(ndims); + break; + case TF_DataType.TF_INT64: + var longs = new long[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + longs[i] = *(long*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(longs).reshape(ndims); + break; + case TF_DataType.TF_FLOAT: + var floats = new float[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + floats[i] = *(float*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(floats).reshape(ndims); + break; + case TF_DataType.TF_DOUBLE: + var doubles = new double[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(doubles).reshape(ndims); + break; + default: + throw new NotImplementedException("can't fetch output"); + } + } + + tensor.Dispose(); + + return nd; + } + + /// + /// If a tensor handle that is fed to a device incompatible placeholder, + /// we move the tensor to the right device, generate a new tensor handle, + /// and update feed_dict to use the new handle. 
+ /// + private List _update_with_movers() + { + return new List { }; + } + + private void _extend_graph() + { + + } + + public void close() + { + Dispose(); + } + + protected override void DisposeUnmanagedResources(IntPtr handle) + { + using (var status = new Status()) + { + c_api.TF_DeleteSession(handle, status); + status.Check(true); + } + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Sessions/FeedItem.cs b/src/TensorFlowNET.Core/Sessions/FeedItem.cs index f87457e7..c3a3dc67 100644 --- a/src/TensorFlowNET.Core/Sessions/FeedItem.cs +++ b/src/TensorFlowNET.Core/Sessions/FeedItem.cs @@ -16,5 +16,11 @@ public static implicit operator FeedItem((object, object) feed) => new FeedItem(feed.Item1, feed.Item2); + + public void Deconstruct(out object key, out object value) + { + key = Key; + value = Value; + } } } diff --git a/src/TensorFlowNET.Core/Sessions/SessionOptions.cs b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs index 8e0a0a74..112543fe 100644 --- a/src/TensorFlowNET.Core/Sessions/SessionOptions.cs +++ b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs @@ -32,13 +32,13 @@ namespace Tensorflow _handle = handle; } - protected override void DisposeUnManagedState(IntPtr handle) + protected override void DisposeUnmanagedResources(IntPtr handle) => c_api.TF_DeleteSessionOptions(handle); public void SetConfig(ConfigProto config) { - var bytes = config.ToByteArray(); - var proto = Marshal.AllocHGlobal(bytes.Length); + var bytes = config.ToByteArray(); //TODO! we can use WriteTo + var proto = Marshal.AllocHGlobal(bytes.Length); //TODO! potential memory leak Marshal.Copy(bytes, 0, proto, bytes.Length); using (var status = new Status()) diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index a46decb1..e1a77d90 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -17,6 +17,7 @@ using NumSharp; using System; using System.Collections.Generic; +using NumSharp.Backends; namespace Tensorflow { @@ -71,18 +72,18 @@ namespace Tensorflow { if(tensor_values.Length > 0) { - switch (tensor_values[0].dtype.Name) + switch (tensor_values[0].typecode) { - case "Int32": + case NPTypeCode.Int32: full_values.Add(float.NaN); break; - case "Single": + case NPTypeCode.Single: full_values.Add(float.NaN); break; - case "String": + case NPTypeCode.String: full_values.Add(float.NaN); break; - case "Char": + case NPTypeCode.Char: full_values.Add(float.NaN); break; default: @@ -100,21 +101,21 @@ namespace Tensorflow j += 1; if (value.ndim == 0) { - switch (value.dtype.Name) + switch (value.typecode) { - case "Int16": + case NPTypeCode.Int16: full_values.Add(value.GetValue(0)); break; - case "Int32": + case NPTypeCode.Int32: full_values.Add(value.GetValue(0)); break; - case "Int64": + case NPTypeCode.Int64: full_values.Add(value.GetValue(0)); break; - case "Single": + case NPTypeCode.Single: full_values.Add(value.GetValue(0)); break; - case "Double": + case NPTypeCode.Double: full_values.Add(value.GetValue(0)); break; /*case "String": diff --git a/src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs b/src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs index c40b2a00..6cbf4eec 100644 --- a/src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs +++ b/src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs @@ -27,13 +27,17 @@ namespace Tensorflow var handle = Marshal.AllocHGlobal(size * num_consumers); int num = TF_OperationOutputConsumers(oper_out, 
handle, num_consumers); var consumers = new string[num_consumers]; - for (int i = 0; i < num; i++) + unsafe { - TF_Input input = Marshal.PtrToStructure(handle + i * size); - consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(input.oper)); + var inputptr = (TF_Input*) handle; + for (int i = 0; i < num; i++) + { + var oper = (inputptr + i)->oper; + consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(oper)); + } } return consumers; } } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Status/Status.cs b/src/TensorFlowNET.Core/Status/Status.cs index 7eb2d7e3..ce561f75 100644 --- a/src/TensorFlowNET.Core/Status/Status.cs +++ b/src/TensorFlowNET.Core/Status/Status.cs @@ -15,6 +15,8 @@ ******************************************************************************/ using System; +using System.Runtime.CompilerServices; +using static Tensorflow.c_api; namespace Tensorflow { @@ -27,36 +29,36 @@ namespace Tensorflow /// /// Error message /// - public string Message => c_api.StringPiece(c_api.TF_Message(_handle)); + public string Message => c_api.StringPiece(TF_Message(_handle)); /// /// Error code /// - public TF_Code Code => c_api.TF_GetCode(_handle); + public TF_Code Code => TF_GetCode(_handle); public Status() { - _handle = c_api.TF_NewStatus(); + _handle = TF_NewStatus(); } public void SetStatus(TF_Code code, string msg) { - c_api.TF_SetStatus(_handle, code, msg); + TF_SetStatus(_handle, code, msg); } /// /// Check status /// Throw exception with error message if code != TF_OK /// + /// When the returned check is not TF_Code.TF_OK + [MethodImpl(MethodImplOptions.AggressiveInlining)] public void Check(bool throwException = false) { - if(Code != TF_Code.TF_OK) + if (Code != TF_Code.TF_OK) { Console.WriteLine(Message); if (throwException) - { - throw new Exception(Message); - } + throw new TensorflowException(Message); } } @@ -65,7 +67,7 @@ namespace Tensorflow return status._handle; } - protected override void DisposeUnManagedState(IntPtr handle) - => c_api.TF_DeleteStatus(handle); + protected override void DisposeUnmanagedResources(IntPtr handle) + => TF_DeleteStatus(handle); } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Status/c_api.status.cs b/src/TensorFlowNET.Core/Status/c_api.status.cs index cfac49d1..ee17e447 100644 --- a/src/TensorFlowNET.Core/Status/c_api.status.cs +++ b/src/TensorFlowNET.Core/Status/c_api.status.cs @@ -51,7 +51,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static unsafe extern IntPtr TF_NewStatus(); + public static extern IntPtr TF_NewStatus(); /// /// Record in *s. Any previous information is lost. 
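A sketch of the reworked `Status.Check`, which now throws `TensorflowException` instead of a bare `Exception` (illustrative; `TF_Code.TF_INVALID_ARGUMENT` is assumed to be defined on the `TF_Code` enum):

```cs
// Illustrative sketch: Check(true) surfaces a non-OK status as a TensorflowException.
using (var status = new Status())
{
    status.SetStatus(TF_Code.TF_INVALID_ARGUMENT, "bad input");
    try
    {
        status.Check(throwException: true);
    }
    catch (TensorflowException ex)
    {
        Console.WriteLine($"{status.Code}: {ex.Message}");
    }
}
```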
diff --git a/src/TensorFlowNET.Core/Summaries/Summary.cs b/src/TensorFlowNET.Core/Summaries/Summary.cs index 2bea0ddc..3d157bd9 100644 --- a/src/TensorFlowNET.Core/Summaries/Summary.cs +++ b/src/TensorFlowNET.Core/Summaries/Summary.cs @@ -33,11 +33,11 @@ namespace Tensorflow.Summaries { var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor }, default_name: "HistogramSummary"); var val = gen_logging_ops.histogram_summary(tag: tag, values: tensor, name: scope); - collect(val, collections?.ToList(), new List { ops.GraphKeys.SUMMARIES }); + collect(val, collections?.ToList(), new List { tf.GraphKeys.SUMMARIES }); return val; } - public Tensor merge_all(string key = ops.GraphKeys.SUMMARIES, string scope= null, string name= null) + public Tensor merge_all(string key = "summaries", string scope= null, string name= null) { var summary_ops = ops.get_collection(key, scope: scope); if (summary_ops == null) @@ -67,7 +67,7 @@ namespace Tensorflow.Summaries { var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor }); var val = gen_logging_ops.scalar_summary(tags: tag, values: tensor, name: scope); - collect(val, collections?.ToList(), new List { ops.GraphKeys.SUMMARIES }); + collect(val, collections?.ToList(), new List { tf.GraphKeys.SUMMARIES }); return val; } diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index abd6e1bf..bd8c0a29 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -5,8 +5,8 @@ TensorFlow.NET Tensorflow 1.14.0 - 0.11.0 - Haiping Chen, Meinrad Recheis + 0.11.1 + Haiping Chen, Meinrad Recheis, Eli Belash SciSharp STACK true Apache 2.0 @@ -17,10 +17,16 @@ TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C# Google's TensorFlow full binding in .NET Standard. Docs: https://tensorflownet.readthedocs.io - 0.11.10.0 - Changes since v0.10.0: + 0.11.1.0 + Changes since v0.10.0: +1. Upgrade NumSharp to v0.20. +2. Add DisposableObject class to manage object lifetime. +3. Add tf.no_op, tf.nn.in_top_k, tf.GraphKeys and tf.trainable_variables. +4. Change tensorflow to non-static class in order to execute some initialization process. +5. Overload session.run(), make syntax simpler. +6. Add Local Response Normalization. 
7.3 - 0.11.10.0 + 0.11.1.0 LICENSE true true @@ -52,7 +58,7 @@ Docs: https://tensorflownet.readthedocs.io - + diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 63fda866..625b424a 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -16,11 +16,13 @@ using NumSharp; using System; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Numerics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Text; +using NumSharp.Backends; using NumSharp.Backends.Unmanaged; using static Tensorflow.c_api; @@ -50,9 +52,9 @@ namespace Tensorflow private DeallocatorArgs _deallocatorArgs = new DeallocatorArgs() { gc_handle = IntPtr.Zero }; // note: they must be assigned to a static variable in order to work as unmanaged callbacks - static Deallocator _hGlobalDeallocator = FreeHGlobalMemory; - static Deallocator _gcHandleDeallocator = FreeGCHandle; - private static Deallocator _nothingDeallocator = FreeNothing; + private static readonly Deallocator _hGlobalDeallocator = FreeHGlobalMemory; + private static readonly Deallocator _gcHandleDeallocator = FreeGCHandle; + private static readonly Deallocator _nothingDeallocator = FreeNothing; /// /// Create a Tensor object from an existing TF handle @@ -462,7 +464,7 @@ namespace Tensorflow *v = value; _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(Complex)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(Complex), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); IsMemoryOwner=true; - } + } #endif /// @@ -477,7 +479,7 @@ namespace Tensorflow IntPtr tensor = c_api.TF_TensorData(handle); Marshal.WriteInt64(tensor, 0); - fixed (byte* src = &buffer[0]) + fixed (byte* src = buffer) c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); _handle = handle; status.Check(true); @@ -486,35 +488,54 @@ namespace Tensorflow public unsafe Tensor(NDArray nd, TF_DataType? 
tensorDType = null) { // todo: handle nd of type "String" here too - if (tensorDType == TF_DataType.TF_STRING && nd.dtype.Name == "Byte") + if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte) { - var buffer = nd.ToArray(); - var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); - var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); - - IntPtr tensor = c_api.TF_TensorData(handle); - Marshal.WriteInt64(tensor, 0); + if (nd.Unsafe.Storage.Shape.IsContiguous) + { + var bytesLength = (UIntPtr)nd.size; + var size = c_api.TF_StringEncodedSize(bytesLength); + var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); + + IntPtr tensor = c_api.TF_TensorData(handle); + Marshal.WriteInt64(tensor, 0); + + var status = new Status(); + c_api.TF_StringEncode((byte*) nd.Unsafe.Address, bytesLength, (sbyte*) (tensor + sizeof(Int64)), size, status); + + status.Check(true); + _handle = handle; + IsMemoryOwner = false; + } + else + { + var buffer = nd.ToArray(); + var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length); + var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); + + IntPtr tensor = c_api.TF_TensorData(handle); + Marshal.WriteInt64(tensor, 0); + + var status = new Status(); + fixed (byte* src = buffer) + c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, status); + + status.Check(true); + _handle = handle; + IsMemoryOwner = false; + } - var status = new Status(); - fixed (byte* src = &buffer[0]) - c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); - - status.Check(true); - _handle=handle; - IsMemoryOwner = false; return; } + _handle = CreateTensorFromNDArray(nd, tensorDType); - IsMemoryOwner = true; } private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype) { - if (nd.dtype.Name == "String") - throw new NotImplementedException("Support for NDArray of type string not implemented yet"); + if (nd.dtype.Name == "String") + throw new NotImplementedException("Support for NDArray of type string not implemented yet"); IArraySlice arraySlice; - var shape = nd.Unsafe.Storage.Shape; - if (shape.IsSliced || shape.IsBroadcasted) + if (nd.Unsafe.Storage.Shape.IsContiguous == false) { // the memory is NOT contiguous, so we have to copy the view into a contiguous memory block. arraySlice = nd.CloneData(); @@ -527,51 +548,52 @@ namespace Tensorflow this.Tag = arraySlice; // keep a reference to the memory block to make sure it is not disposed while TF is using it var ptr = new IntPtr(arraySlice.Address); int num_bytes = (nd.size * nd.dtypesize); - var dtype = given_dtype ?? ToTFDataType(nd.dtype); + var dtype = given_dtype ?? 
nd.dtype.as_dtype(); var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs); IsMemoryOwner = false; return handle; - } - - public unsafe Tensor(byte[][] buffer, long[] shape) - { - int size = 0; - foreach (var b in buffer) - { - size += (int)TF_StringEncodedSize((UIntPtr)b.Length); - } - int totalSize = size + buffer.Length * 8; - ulong offset = 0; - IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize); - - // Clear offset table - IntPtr pOffset = TF_TensorData(handle); - IntPtr dst = pOffset + buffer.Length * 8; - IntPtr dstLimit = pOffset + totalSize; - for (int i = 0; i < buffer.Length; i++) - { - Marshal.WriteInt64(pOffset, (long)offset); - using (var status = new Status()) - { - fixed (byte* src = &buffer[i][0]) - { - var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status); - status.Check(true); - pOffset += 8; - dst += (int)written; - offset += written; - } - } - } - - _handle = handle; + + } + + public unsafe Tensor(byte[][] buffer, long[] shape) + { + int size = 0; + foreach (var b in buffer) + { + size += (int)TF_StringEncodedSize((UIntPtr)b.Length); + } + int totalSize = size + buffer.Length * 8; + ulong offset = 0; + IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize); + + // Clear offset table + IntPtr pOffset = TF_TensorData(handle); + IntPtr dst = pOffset + buffer.Length * 8; + IntPtr dstLimit = pOffset + totalSize; + for (int i = 0; i < buffer.Length; i++) + { + Marshal.WriteInt64(pOffset, (long)offset); + using (var status = new Status()) + { + fixed (byte* src = &buffer[i][0]) + { + var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status); + status.Check(true); + pOffset += 8; + dst += (int)written; + offset += written; + } + } + } + + _handle = handle; } public Tensor(Operation op, int value_index, TF_DataType dtype) { _op = op; _value_index = value_index; - _dtype = dtype; + _override_dtype = dtype; _id = ops.uid(); } @@ -589,11 +611,11 @@ namespace Tensorflow /// specified dimensions. 
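A sketch of building a `Tensor` from a NumSharp `NDArray` using the contiguous fast path added above (illustrative only; the array values are arbitrary):

```cs
// Illustrative sketch: a contiguous NDArray is pinned and handed to TF without a copy;
// sliced or broadcast views are cloned into a contiguous block first.
var nd = np.array(new float[] { 1f, 2f, 3f, 4f }).reshape(2, 2);
var tensor = new Tensor(nd);                         // dtype inferred via nd.dtype.as_dtype()
var typed = new Tensor(nd, TF_DataType.TF_FLOAT);    // or pass the TF dtype explicitly
```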
/// [MethodImpl(MethodImplOptions.AggressiveInlining)] + [SuppressMessage("ReSharper", "LocalVariableHidesMember")] protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int element_size) { - if (dt == TF_DataType.TF_STRING && data is byte[]) + if (dt == TF_DataType.TF_STRING && data is byte[] buffer) { - var buffer = (byte[])data; var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); @@ -601,7 +623,7 @@ namespace Tensorflow Marshal.WriteInt64(tensor, 0); var status = new Status(); - fixed (byte* src = &buffer[0]) + fixed (byte* src = buffer) c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); status.Check(true); @@ -644,8 +666,9 @@ namespace Tensorflow { if (args.deallocator_called) return; + // NumSharp will dispose - // Marshal.FreeHGlobal(dataPtr); + Marshal.FreeHGlobal(dataPtr); args.deallocator_called = true; } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs index 6db60b4a..6d7f20f1 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs @@ -1,4 +1,5 @@ using System; +using System.Runtime.CompilerServices; namespace Tensorflow { @@ -6,86 +7,142 @@ namespace Tensorflow { public static explicit operator bool(Tensor tensor) { - EnsureScalar(tensor); - return tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_BOOL); + return *(bool*) tensor.buffer; + } } public static explicit operator sbyte(Tensor tensor) { - EnsureScalar(tensor); - return tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_INT8); + return *(sbyte*) tensor.buffer; + } } public static explicit operator byte(Tensor tensor) { - EnsureScalar(tensor); - return tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_UINT8); + return *(byte*) tensor.buffer; + } } public static explicit operator ushort(Tensor tensor) { - EnsureScalar(tensor); - return tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_UINT16); + return *(ushort*) tensor.buffer; + } } public static explicit operator short(Tensor tensor) { - EnsureScalar(tensor); - return tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_INT16); + return *(short*) tensor.buffer; + } } public static explicit operator int(Tensor tensor) { - EnsureScalar(tensor); - return tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_INT32); + return *(int*) tensor.buffer; + } } public static explicit operator uint(Tensor tensor) { - EnsureScalar(tensor); - return tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_UINT32); + return *(uint*) tensor.buffer; + } } public static explicit operator long(Tensor tensor) { - EnsureScalar(tensor); - return tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_INT64); + return *(long*) tensor.buffer; + } } public static explicit operator ulong(Tensor tensor) { - EnsureScalar(tensor); - return tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_UINT64); + return *(ulong*) tensor.buffer; + } } public static explicit operator float(Tensor tensor) { - EnsureScalar(tensor); - return 
tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_FLOAT); + return *(float*) tensor.buffer; + } } public static explicit operator double(Tensor tensor) { - EnsureScalar(tensor); - return tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_DOUBLE); + return *(double*) tensor.buffer; + } + } + + public static explicit operator string(Tensor tensor) + { + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_STRING); + return new string((char*) tensor.buffer, 0, (int) tensor.size); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void EnsureDType(Tensor tensor, TF_DataType @is) + { + if (tensor.dtype != @is) + throw new InvalidCastException($"Unable to cast scalar tensor {tensor.dtype} to {@is}"); } + [MethodImpl(MethodImplOptions.AggressiveInlining)] private static void EnsureScalar(Tensor tensor) { if (tensor == null) - { throw new ArgumentNullException(nameof(tensor)); - } if (tensor.TensorShape.ndim != 0) - { throw new ArgumentException("Tensor must have 0 dimensions in order to convert to scalar"); - } if (tensor.TensorShape.size != 1) - { throw new ArgumentException("Tensor must have size 1 in order to convert to scalar"); - } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs index eb912eb9..4b15864f 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs @@ -69,11 +69,12 @@ namespace Tensorflow TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32, TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 }; + public static Tensor operator /(double x, Tensor y) => BinaryOpWrapper("truediv", x, y); public static Tensor operator /(float x, Tensor y) => BinaryOpWrapper("truediv", x, y); public static Tensor operator /(int x, Tensor y) => BinaryOpWrapper("floordiv", x, y); public static Tensor operator /(Tensor x, Tensor y) => - _intTfDataTypes.Contains(x._dtype) + _intTfDataTypes.Contains(x.dtype) ? BinaryOpWrapper("floordiv", x, y) : BinaryOpWrapper("truediv", x, y); public static Tensor operator /(Tensor x, int y) => BinaryOpWrapper("floordiv", x, y); @@ -122,8 +123,7 @@ namespace Tensorflow if (y is Tensor tr) dtype = tr.dtype.as_base_dtype(); - var namescope = ops.name_scope(null, name, new { x, y }); - return tf_with(namescope, scope => + return tf_with(ops.name_scope(null, name, new { x, y }), scope => { Tensor result = null; var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); @@ -155,7 +155,6 @@ namespace Tensorflow return result; }); - } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index d52b9422..75cba69e 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -17,9 +17,16 @@ using NumSharp; using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; using System.Linq; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Text; +using System.Threading.Tasks; +using NumSharp.Backends; +using NumSharp.Backends.Unmanaged; +using NumSharp.Utilities; using Tensorflow.Framework; using static Tensorflow.Binding; @@ -29,42 +36,68 @@ namespace Tensorflow /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. 
/// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. /// + [SuppressMessage("ReSharper", "ConvertToAutoProperty")] public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike { - private int _id; - private Operation _op; + private readonly int _id; + private readonly Operation _op; + private readonly int _value_index; + private TF_Output? _tf_output; + private readonly TF_DataType _override_dtype; public int Id => _id; + + /// + /// The Graph that contains this tensor. + /// public Graph graph => op?.graph; + + /// + /// The Operation that produces this tensor as an output. + /// public Operation op => _op; + public Tensor[] outputs => op.outputs; /// - /// The string name of this tensor. + /// The string name of this tensor. /// public string name => $"{(op == null ? "" : $"{op.name}:{_value_index}")}"; - private int _value_index; + /// + /// The index of this tensor in the outputs of its Operation. + /// public int value_index => _value_index; - private TF_DataType _dtype = TF_DataType.DtInvalid; - public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle); + /// + /// The DType of elements in this tensor. + /// + public TF_DataType dtype => _handle == IntPtr.Zero ? _override_dtype : c_api.TF_TensorType(_handle); public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); - public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; - public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); + public int NDims => rank; - private TF_Output? _tf_output; + /// + /// The name of the device on which this tensor will be produced, or null. + /// + public string Device => op.Device; + + public int[] dims => shape; /// - /// used for keep other pointer when do implicit operating + /// Used for keep other pointer when do implicit operating /// public object Tag { get; set; } + + /// + /// Returns the shape of a tensor. + /// + /// https://www.tensorflow.org/api_docs/python/tf/shape public int[] shape { get @@ -76,14 +109,13 @@ namespace Tensorflow var status = new Status(); c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); status.Check(); - } - else + } else { for (int i = 0; i < rank; i++) dims[i] = c_api.TF_Dim(_handle, i); } - return dims.Select(x => Convert.ToInt32(x)).ToArray(); + return dims.Select(x => ((IConvertible) x).ToInt32(CultureInfo.InvariantCulture)).ToArray(); } set @@ -93,38 +125,52 @@ namespace Tensorflow if (value == null) c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); else - c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(x => Convert.ToInt64(x)).ToArray(), value.Length, status); + c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); } } public int[] _shape_tuple() { - if (shape == null) return null; - return shape.Select(x => (int)x).ToArray(); + return (int[]) shape.Clone(); } public TensorShape TensorShape => tensor_util.to_shape(shape); - public void SetShape(TensorShape shape) + /// + /// Updates the shape of this tensor. 
+ /// + public void set_shape(TensorShape shape) { - this.shape = shape.dims; + this.shape = (int[]) shape.dims.Clone(); } + /// + /// Updates the shape of this tensor. + /// + [Obsolete("Please use set_shape(TensorShape shape) instead.", false)] + public void SetShape(TensorShape shape) + { + this.shape = (int[]) shape.dims.Clone(); + } + + /// + /// Updates the shape of this tensor. + /// public void set_shape(Tensor shape) { + // ReSharper disable once MergeConditionalExpression this.shape = shape is null ? null : shape.shape; } - public int[] dims => shape; - /// - /// number of dimensions - /// 0 Scalar (magnitude only) - /// 1 Vector (magnitude and direction) - /// 2 Matrix (table of numbers) - /// 3 3-Tensor (cube of numbers) + /// number of dimensions
+ /// 0 Scalar (magnitude only)
+ /// 1 Vector (magnitude and direction)
+ /// 2 Matrix (table of numbers)
+ /// 3 3-Tensor (cube of numbers)
/// n n-Tensor (you get the idea) ///
+ /// https://www.tensorflow.org/api_docs/python/tf/rank public int rank { get @@ -137,17 +183,15 @@ namespace Tensorflow status.Check(); return ndim; } - else - { - return c_api.TF_NumDims(_handle); - } + + return c_api.TF_NumDims(_handle); } } - public int NDims => rank; - - public string Device => op.Device; - + /// + /// Returns a list of Operations that consume this tensor. + /// + /// public Operation[] consumers() { var output = _as_tf_output(); @@ -157,38 +201,181 @@ namespace Tensorflow public TF_Output _as_tf_output() { - if(!_tf_output.HasValue) + if (!_tf_output.HasValue) _tf_output = new TF_Output(op, value_index); return _tf_output.Value; } - public T[] Data() + [Obsolete("Please use ToArray() instead.", false)] + public T[] Data() where T : unmanaged { - // Column major order - // https://en.wikipedia.org/wiki/File:Row_and_column_major_order.svg - // matrix:[[1, 2, 3], [4, 5, 6]] - // index: 0 2 4 1 3 5 - // result: 1 4 2 5 3 6 - var data = new T[size]; - - for (ulong i = 0; i < size; i++) + return ToArray(); + } + + /// + /// + /// + /// + /// + /// When is string + public T[] ToArray() where T : unmanaged + { + //Are the types matching? + if (typeof(T).as_dtype() == dtype) { - data[i] = Marshal.PtrToStructure(buffer + (int)(i * itemsize)); - } + if (NDims == 0 && size == 1) //is it a scalar? + { + unsafe + { + return new T[] {*(T*) buffer}; + } + } + + //types match, no need to perform cast + var ret = new T[size]; + unsafe + { + var len = (long) size; + fixed (T* dst = ret) + { + //T can only be unmanaged, I believe it is safe to say that MemoryCopy is valid for all cases this method can be called. + var src = (T*) buffer; + len *= ((long) itemsize); + System.Buffer.MemoryCopy(src, dst, len, len); + } + } + + return ret; + } else + { + + //types do not match, need to perform cast + if (NDims == 0 && size == 1) //is it a scalar? 
+ { + unsafe + { +#if _REGEN + #region Compute + switch (dtype.as_numpy_dtype().GetTypeCode()) + { + %foreach supported_dtypes,supported_dtypes_lowercase% + case NPTypeCode.#1: return new T[] {Converts.ChangeType(*(#2*) buffer, NPTypeCode.#1)}; + % + case NPTypeCode.String: return new T[] {Converts.ChangeType((string)this, NPTypeCode.String)}; + default: + throw new NotSupportedException(); + } + #endregion +#else + #region Compute + switch (dtype.as_numpy_dtype()?.GetTypeCode()) + { + case NPTypeCode.Boolean: return new T[] {Converts.ChangeType(*(bool*) buffer, NPTypeCode.Boolean)}; + case NPTypeCode.Byte: return new T[] {Converts.ChangeType(*(byte*) buffer, NPTypeCode.Byte)}; + case NPTypeCode.Int16: return new T[] {Converts.ChangeType(*(short*) buffer, NPTypeCode.Int16)}; + case NPTypeCode.UInt16: return new T[] {Converts.ChangeType(*(ushort*) buffer, NPTypeCode.UInt16)}; + case NPTypeCode.Int32: return new T[] {Converts.ChangeType(*(int*) buffer, NPTypeCode.Int32)}; + case NPTypeCode.UInt32: return new T[] {Converts.ChangeType(*(uint*) buffer, NPTypeCode.UInt32)}; + case NPTypeCode.Int64: return new T[] {Converts.ChangeType(*(long*) buffer, NPTypeCode.Int64)}; + case NPTypeCode.UInt64: return new T[] {Converts.ChangeType(*(ulong*) buffer, NPTypeCode.UInt64)}; + case NPTypeCode.Char: return new T[] {Converts.ChangeType(*(char*) buffer, NPTypeCode.Char)}; + case NPTypeCode.Double: return new T[] {Converts.ChangeType(*(double*) buffer, NPTypeCode.Double)}; + case NPTypeCode.Single: return new T[] {Converts.ChangeType(*(float*) buffer, NPTypeCode.Single)}; + case NPTypeCode.String: return new T[] {Converts.ChangeType((string)this, NPTypeCode.String)}; + default: + throw new NotSupportedException(); + } + #endregion +#endif + } + } - return data; + var ret = new T[size]; + unsafe + { + var len = (long) size; + fixed (T* dstRet = ret) + { + T* dst = dstRet; //local stack copy + +#if _REGEN + #region Compute + switch (dtype.as_numpy_dtype().GetTypeCode()) + { + %foreach supported_dtypes,supported_dtypes_lowercase% + case NPTypeCode.#1: new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + % + default: + throw new NotSupportedException(); + } + #endregion +#else + #region Compute + switch (dtype.as_numpy_dtype().GetTypeCode()) + { + case NPTypeCode.Boolean: new UnmanagedMemoryBlock((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Byte: new UnmanagedMemoryBlock((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Int16: new UnmanagedMemoryBlock((short*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.UInt16: new UnmanagedMemoryBlock((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Int32: new UnmanagedMemoryBlock((int*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.UInt32: new UnmanagedMemoryBlock((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Int64: new UnmanagedMemoryBlock((long*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.UInt64: new UnmanagedMemoryBlock((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Char: new UnmanagedMemoryBlock((char*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case 
NPTypeCode.Double: new UnmanagedMemoryBlock((double*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Single: new UnmanagedMemoryBlock((float*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.String: throw new NotSupportedException("Unable to convert from string to other dtypes"); //TODO! this should call Converts.To + default: + throw new NotSupportedException(); + } + #endregion +#endif + + } + } + + return ret; + } } + /// + /// Copies the memory of current buffer onto newly allocated array. + /// + /// + [Obsolete("Please use set_shape(TensorShape shape) instead.", false)] public byte[] Data() { - var data = new byte[bytesize]; - Marshal.Copy(buffer, data, 0, (int)bytesize); - return data; + return BufferToArray(); + } + + /// + /// Copies the memory of current buffer onto newly allocated array. + /// + /// + public byte[] BufferToArray() + { + unsafe + { + // ReSharper disable once LocalVariableHidesMember + var bytesize = (long) this.bytesize; + var data = new byte[bytesize]; + fixed (byte* dst = data) + System.Buffer.MemoryCopy(buffer.ToPointer(), dst, bytesize, bytesize); + + return data; + } } + /// + /// Extracts string array from current Tensor. + /// + /// When != TF_DataType.TF_STRING public unsafe string[] StringData() { + if (dtype != TF_DataType.TF_STRING) + throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})"); + // // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. // [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes] @@ -199,19 +386,19 @@ namespace Tensorflow var buffer = new byte[size][]; var src = c_api.TF_TensorData(_handle); - var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize); - src += (int)(size * 8); + var srcLen = (IntPtr) (src.ToInt64() + (long) bytesize); + src += (int) (size * 8); for (int i = 0; i < buffer.Length; i++) { using (var status = new Status()) { IntPtr dst = IntPtr.Zero; UIntPtr dstLen = UIntPtr.Zero; - var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status); + var read = c_api.TF_StringDecode((byte*) src, (UIntPtr) (srcLen.ToInt64() - src.ToInt64()), (byte**) &dst, &dstLen, status); status.Check(true); - buffer[i] = new byte[(int)dstLen]; + buffer[i] = new byte[(int) dstLen]; Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); - src += (int)read; + src += (int) read; } } @@ -229,51 +416,29 @@ namespace Tensorflow } /// - /// Evaluates this tensor in a `Session`. + /// Evaluates this tensor in a `Session`. /// /// A dictionary that maps `Tensor` objects to feed values. - /// The `Session` to be used to evaluate this tensor. - /// + /// A array corresponding to the value of this tensor. public NDArray eval(params FeedItem[] feed_dict) { return ops._eval_using_default_session(this, feed_dict, graph); } - public NDArray eval(Session session, FeedItem[] feed_dict = null) + /// + /// Evaluates this tensor in a `Session`. + /// + /// A dictionary that maps `Tensor` objects to feed values. + /// The `Session` to be used to evaluate this tensor. + /// A array corresponding to the value of this tensor. 
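/// A hypothetical call site for this overload (names are illustrative; the pattern follows the example changes elsewhere in this diff):
/// float acc = accuracy.eval(sess, (x, x_test), (y, y_test));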
+ public NDArray eval(Session session, params FeedItem[] feed_dict) { return ops._eval_using_default_session(this, feed_dict, graph, session); } - public TF_DataType ToTFDataType(Type type) - { - switch (type.Name) - { - case "Char": - return TF_DataType.TF_UINT8; - case "Int16": - return TF_DataType.TF_INT16; - case "Int32": - return TF_DataType.TF_INT32; - case "Int64": - return TF_DataType.TF_INT64; - case "Single": - return TF_DataType.TF_FLOAT; - case "Double": - return TF_DataType.TF_DOUBLE; - case "Byte": - return TF_DataType.TF_UINT8; - case "String": - return TF_DataType.TF_STRING; - case "Boolean": - return TF_DataType.TF_BOOL; - default: - throw new NotImplementedException("ToTFDataType error"); - } - } - public Tensor slice(Slice slice) { - var slice_spec = new int[] { slice.Start.Value }; + var slice_spec = new int[] {slice.Start.Value}; var begin = new List(); var end = new List(); var strides = new List(); @@ -289,26 +454,26 @@ namespace Tensorflow if (slice.Stop.HasValue) { end.Add(slice.Stop.Value); - } - else + } else { end.Add(0); end_mask |= (1 << index); } + strides.Add(slice.Step); index += 1; } - return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => + return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope => { string name = scope; if (begin != null) { var (packed_begin, packed_end, packed_strides) = (array_ops.stack(begin.ToArray()), - array_ops.stack(end.ToArray()), - array_ops.stack(strides.ToArray())); + array_ops.stack(end.ToArray()), + array_ops.stack(strides.ToArray())); return gen_array_ops.strided_slice( this, @@ -320,7 +485,6 @@ namespace Tensorflow shrink_axis_mask: shrink_axis_mask, new_axis_mask: new_axis_mask, ellipsis_mask: ellipsis_mask, - name: name); } @@ -330,7 +494,7 @@ namespace Tensorflow public Tensor slice(int start) { - var slice_spec = new int[] { start }; + var slice_spec = new int[] {start}; var begin = new List(); var end = new List(); var strides = new List(); @@ -349,15 +513,15 @@ namespace Tensorflow index += 1; } - return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => + return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope => { string name = scope; if (begin != null) { var (packed_begin, packed_end, packed_strides) = (array_ops.stack(begin.ToArray()), - array_ops.stack(end.ToArray()), - array_ops.stack(strides.ToArray())); + array_ops.stack(end.ToArray()), + array_ops.stack(strides.ToArray())); return gen_array_ops.strided_slice( this, @@ -369,7 +533,6 @@ namespace Tensorflow shrink_axis_mask: shrink_axis_mask, new_axis_mask: new_axis_mask, ellipsis_mask: ellipsis_mask, - name: name); } @@ -392,29 +555,13 @@ namespace Tensorflow return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; } - protected override void DisposeManagedState() + protected override void DisposeUnmanagedResources(IntPtr handle) { + c_api.TF_DeleteTensor(handle); } - protected override void DisposeUnManagedState(IntPtr handle) - { - if(handle != IntPtr.Zero) - { - c_api.TF_DeleteTensor(handle); - } - } - - public bool IsDisposed - { - get - { - lock (this) - { - return _handle == IntPtr.Zero; - } - } - } + public bool IsDisposed => _disposed; public int tensor_int_val { get; set; } } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 13258f79..8c9e571e 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs 
+++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -1,35 +1,84 @@ using NumSharp; using System; +using System.Diagnostics.CodeAnalysis; using System.Linq; +using System.Runtime.CompilerServices; namespace Tensorflow { /// - /// Represents the shape of a `Tensor`. + /// Represents the shape of a `Tensor`. /// + /// https://www.tensorflow.org/api_docs/python/tf/TensorShape public class TensorShape { - private Shape shape; + private readonly Shape shape; + + /// + /// Returns a list of Dimensions, or None if the shape is unspecified. + /// public int[] dims => shape.Dimensions; + + /// + /// Returns the rank of this shape. + /// public int ndim => shape.NDim; + + /// + /// Returns the rank of this shape. + /// + public int rank => shape.NDim; + + /// + /// Returns the size this shape represents. + /// public int size => shape.Size; public TensorShape(TensorShapeProto proto) { if (proto.UnknownRank) return; + switch (proto.Dim.Count) + { + case 0: shape = new Shape(new int[0]); break; + case 1: shape = Shape.Vector((int) proto.Dim[0].Size); break; + case 2: shape = Shape.Matrix((int) proto.Dim[0].Size, (int) proto.Dim[1].Size); break; + default: + var protodims = proto.Dim; + var len = protodims.Count; + var dims = new int[len]; + for (int i = 0; i < len; i++) + dims[i] = (int) protodims[i].Size; - shape.reshape(proto.Dim.Select(x => (int)x.Size).ToArray()); + + shape = new Shape(dims); break; + } } public TensorShape(params int[] dims) { - shape = new Shape(dims); + switch (dims.Length) + { + case 0: shape = new Shape(new int[0]); break; + case 1: shape = Shape.Vector((int) dims[0]); break; + case 2: shape = Shape.Matrix(dims[0], dims[1]); break; + default: shape = new Shape(dims); break; + } } + /// + /// + /// + /// + /// + /// When is not an Index. + [SuppressMessage("ReSharper", "PossibleInvalidOperationException")] public TensorShape this[Slice slice] { get { + if (slice.Start.HasValue == false || slice.Length.HasValue == false) + throw new ArgumentException("Slice must has Start and Length."); + return new TensorShape(dims.Skip(slice.Start.Value) .Take(slice.Length.Value) .ToArray()); @@ -37,7 +86,7 @@ namespace Tensorflow } /// - /// Returns True iff `self` is fully defined in every dimension. + /// Returns True iff `self` is fully defined in every dimension. /// /// public bool is_fully_defined() @@ -50,6 +99,7 @@ namespace Tensorflow throw new NotImplementedException("TensorShape is_compatible_with"); } + [SuppressMessage("ReSharper", "ParameterHidesMember")] public TensorShape with_rank_at_least(int rank) { if (rank != ndim) @@ -59,35 +109,68 @@ namespace Tensorflow } /// - /// Returns the concatenation of the dimension in `self` and `other`. + /// Returns the concatenation of the dimension in `self` and `other`. + /// + /// + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public TensorShape concatenate(int[] other) + { + return concatenate(new TensorShape(other)); + } + + /// + /// Returns the concatenation of the dimension in `self` and `other`. 
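/// For example (illustrative sketch): new TensorShape(2, 3).concatenate(new[] { 4 }) produces a (2, 3, 4) shape.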
/// /// /// - public TensorShape concatenate(int[] other_) + public TensorShape concatenate(TensorShape other) { - var other = new TensorShape(other_); + var otherShape = other; - if (ndim < 0 || other.ndim < 0) + if (ndim < 0 || otherShape.ndim < 0) return new TensorShape(); else { - var concatenate_dims = new int[ndim + other.ndim]; + var concatenate_dims = new int[ndim + otherShape.ndim]; for (int i = 0; i < ndim; i++) concatenate_dims[i] = dims[i]; - for (int i = 0; i < other.ndim; i++) - concatenate_dims[ndim + i] = other.dims[i]; + for (int i = 0; i < otherShape.ndim; i++) + concatenate_dims[ndim + i] = otherShape.dims[i]; return new TensorShape(concatenate_dims); } } - public static implicit operator TensorShape(Shape shape) => new TensorShape(shape.Dimensions); - public static implicit operator Shape(TensorShape shape) => new Shape(shape.dims); + public override string ToString() + { + return shape.ToString(); + } + + public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone()); + public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone()); + + public static implicit operator int[](TensorShape shape) => (int[])shape.dims.Clone(); //we clone to avoid any changes public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); - public static implicit operator int[](TensorShape shape) => shape.dims; + + public static explicit operator int(TensorShape shape) => shape.size; + public static explicit operator TensorShape(int dim) => new TensorShape(dim); + + public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0); public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); + + public static explicit operator (int, int, int)(TensorShape shape) => shape.dims.Length == 3 ? (shape.dims[0], shape.dims[1], shape.dims[2]) : (0, 0, 0); public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); + + public static explicit operator (int, int, int, int)(TensorShape shape) => shape.dims.Length == 4 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3]) : (0, 0, 0, 0); public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); + + public static explicit operator (int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 5 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4]) : (0, 0, 0, 0, 0); + public static implicit operator TensorShape((int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5); + + public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? 
(shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0); + public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6); + } } diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index 807dc6f5..37f1ca61 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -15,6 +15,8 @@ ******************************************************************************/ using System; +using System.Numerics; +using NumSharp.Backends; namespace Tensorflow { @@ -23,35 +25,100 @@ namespace Tensorflow public static TF_DataType int8 = TF_DataType.TF_INT8; public static TF_DataType int32 = TF_DataType.TF_INT32; public static TF_DataType int64 = TF_DataType.TF_INT64; + public static TF_DataType uint8 = TF_DataType.TF_UINT8; + public static TF_DataType uint32 = TF_DataType.TF_UINT32; + public static TF_DataType uint64 = TF_DataType.TF_UINT64; public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? public static TF_DataType float16 = TF_DataType.TF_HALF; public static TF_DataType float64 = TF_DataType.TF_DOUBLE; - public static Type as_numpy_datatype(this TF_DataType type) + /// + /// + /// + /// + /// equivalent to , if none exists, returns null. + public static Type as_numpy_dtype(this TF_DataType type) { switch (type) { case TF_DataType.TF_BOOL: return typeof(bool); + case TF_DataType.TF_UINT8: + return typeof(byte); case TF_DataType.TF_INT64: return typeof(long); + case TF_DataType.TF_UINT64: + return typeof(ulong); case TF_DataType.TF_INT32: return typeof(int); + case TF_DataType.TF_UINT32: + return typeof(uint); case TF_DataType.TF_INT16: return typeof(short); + case TF_DataType.TF_UINT16: + return typeof(ushort); case TF_DataType.TF_FLOAT: return typeof(float); case TF_DataType.TF_DOUBLE: return typeof(double); case TF_DataType.TF_STRING: return typeof(string); + case TF_DataType.TF_COMPLEX128: + case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX + return typeof(Complex); default: return null; } } - // "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex" - public static TF_DataType as_dtype(Type type, TF_DataType? dtype = null) + /// + /// + /// + /// + /// + /// When has no equivalent + public static NPTypeCode as_numpy_typecode(this TF_DataType type) + { + switch (type) + { + case TF_DataType.TF_BOOL: + return NPTypeCode.Boolean; + case TF_DataType.TF_UINT8: + return NPTypeCode.Byte; + case TF_DataType.TF_INT64: + return NPTypeCode.Int64; + case TF_DataType.TF_INT32: + return NPTypeCode.Int32; + case TF_DataType.TF_INT16: + return NPTypeCode.Int16; + case TF_DataType.TF_UINT64: + return NPTypeCode.UInt64; + case TF_DataType.TF_UINT32: + return NPTypeCode.UInt32; + case TF_DataType.TF_UINT16: + return NPTypeCode.UInt16; + case TF_DataType.TF_FLOAT: + return NPTypeCode.Single; + case TF_DataType.TF_DOUBLE: + return NPTypeCode.Double; + case TF_DataType.TF_STRING: + return NPTypeCode.String; + case TF_DataType.TF_COMPLEX128: + case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX + return NPTypeCode.Complex; + default: + throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); + } + } + + /// + /// + /// + /// + /// + /// + /// When has no equivalent + public static TF_DataType as_dtype(this Type type, TF_DataType? 
dtype = null) { switch (type.Name) { @@ -98,7 +165,7 @@ namespace Tensorflow dtype = TF_DataType.TF_BOOL; break; default: - throw new Exception("as_dtype Not Implemented"); + throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); } return dtype.Value; @@ -106,16 +173,7 @@ namespace Tensorflow public static DataType as_datatype_enum(this TF_DataType type) { - DataType dtype = DataType.DtInvalid; - - switch (type) - { - default: - Enum.TryParse(((int)type).ToString(), out dtype); - break; - } - - return dtype; + return Enum.TryParse(((int) type).ToString(), out DataType dtype) ? dtype : DataType.DtInvalid; } public static TF_DataType as_base_dtype(this TF_DataType type) @@ -132,7 +190,7 @@ namespace Tensorflow public static Type as_numpy_dtype(this DataType type) { - return type.as_tf_dtype().as_numpy_datatype(); + return type.as_tf_dtype().as_numpy_dtype(); } public static DataType as_base_dtype(this DataType type) @@ -144,16 +202,7 @@ namespace Tensorflow public static TF_DataType as_tf_dtype(this DataType type) { - TF_DataType dtype = TF_DataType.DtInvalid; - - switch (type) - { - default: - Enum.TryParse(((int)type).ToString(), out dtype); - break; - } - - return dtype; + return Enum.TryParse(((int) type).ToString(), out TF_DataType dtype) ? dtype : TF_DataType.DtInvalid; } public static TF_DataType as_ref(this TF_DataType type) diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index ded105c7..59c107fc 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -17,6 +17,7 @@ using NumSharp; using System; using System.Linq; +using NumSharp.Utilities; namespace Tensorflow { @@ -82,6 +83,12 @@ namespace Tensorflow throw new NotImplementedException("MakeNdarray"); } + private static readonly TF_DataType[] quantized_types = new TF_DataType[] + { + TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16, + TF_DataType.TF_QINT32 + }; + /// /// Create a TensorProto. /// @@ -98,18 +105,9 @@ namespace Tensorflow if (values is TensorProto tp) return tp; - if (dtype != TF_DataType.DtInvalid) - ; - - bool is_quantized = new TF_DataType[] - { - TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16, - TF_DataType.TF_QINT32 - }.Contains(dtype); - // We first convert value to a numpy array or scalar. 
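// (Sketch of how the dtype helpers above are meant to line up, based on the mappings in dtypes.cs:
//   typeof(float).as_dtype() == TF_DataType.TF_FLOAT,
//   TF_DataType.TF_FLOAT.as_numpy_dtype() == typeof(float),
//   TF_DataType.TF_FLOAT.as_numpy_typecode() == NPTypeCode.Single.)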
NDArray nparray = null; - var np_dt = dtype.as_numpy_datatype(); + var np_dt = dtype.as_numpy_dtype(); if (values is NDArray nd) { @@ -188,37 +186,37 @@ namespace Tensorflow if (values.GetType().IsArray) nparray = np.array((int[])values, np_dt); else - nparray = Convert.ToInt32(values); + nparray = Converts.ToInt32(values); break; case "Int64": if (values.GetType().IsArray) nparray = np.array((int[])values, np_dt); else - nparray = Convert.ToInt64(values); + nparray = Converts.ToInt64(values); break; case "Single": if (values.GetType().IsArray) nparray = np.array((float[])values, np_dt); else - nparray = Convert.ToSingle(values); + nparray = Converts.ToSingle(values); break; case "Double": if (values.GetType().IsArray) nparray = np.array((double[])values, np_dt); else - nparray = Convert.ToDouble(values); + nparray = Converts.ToDouble(values); break; case "String": if (values.GetType().IsArray) nparray = np.array((string[])values, np_dt); else - nparray = NDArray.FromString(Convert.ToString(values)); + nparray = NDArray.FromString(Converts.ToString(values)); break; case "Boolean": if (values.GetType().IsArray) nparray = np.array((bool[])values, np_dt); else - nparray = Convert.ToBoolean(values); + nparray = Converts.ToBoolean(values); break; default: throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented"); @@ -226,13 +224,13 @@ namespace Tensorflow } } - var numpy_dtype = dtypes.as_dtype(nparray.dtype, dtype: dtype); + var numpy_dtype = nparray.dtype.as_dtype(dtype: dtype); if (numpy_dtype == TF_DataType.DtInvalid) throw new TypeError($"Unrecognized data type: {nparray.dtype}"); // If dtype was specified and is a quantized type, we convert // numpy_dtype back into the quantized version. - if (is_quantized) + if (quantized_types.Contains(dtype)) numpy_dtype = dtype; bool is_same_size = false; diff --git a/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs b/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs new file mode 100644 index 00000000..e129edce --- /dev/null +++ b/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs @@ -0,0 +1,52 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.Train +{ + public class ExponentialMovingAverage + { + float _decay; + int? _num_updates; + bool _zero_debias; + string _name; + public string name => _name; + List _averages; + + public ExponentialMovingAverage(float decay, int? num_updates = null, bool zero_debias = false, + string name = "ExponentialMovingAverage") + { + _decay = decay; + _num_updates = num_updates; + _zero_debias = zero_debias; + _name = name; + _averages = new List(); + } + + /// + /// Maintains moving averages of variables. 
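/// A minimal sketch of the intended update, assuming the usual EMA rule (apply() below is still unimplemented):
/// for each tracked variable v, shadow_v = decay * shadow_v + (1 - decay) * v each time the returned op runs.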
+ /// + /// + /// + public Operation apply(RefVariable[] var_list = null) + { + if (var_list == null) + var_list = variables.trainable_variables() as RefVariable[]; + + foreach(var var in var_list) + { + if (!_averages.Contains(var)) + { + ops.init_scope(); + var slot = new SlotCreator(); + var.initialized_value(); + // var avg = slot.create_zeros_slot + } + } + + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs index c031da54..bb8fcd7a 100644 --- a/src/TensorFlowNET.Core/Train/Optimizer.cs +++ b/src/TensorFlowNET.Core/Train/Optimizer.cs @@ -198,7 +198,7 @@ namespace Tensorflow if (!tf.context.executing_eagerly()) { - var train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) as List; + var train_op = ops.get_collection_ref(tf.GraphKeys.TRAIN_OP) as List; if (train_op != null && train_op.Contains(apply_updates)) train_op.Add(apply_updates); } @@ -359,7 +359,7 @@ namespace Tensorflow var tmp = variables.trainable_variables(); - var vars = ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); + var vars = ops.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); switch (tmp) { case List values: @@ -370,7 +370,7 @@ namespace Tensorflow break; } - var_list = var_list.Concat(ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); + var_list = var_list.Concat(ops.get_collection(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); var processors = var_list.Select(v => optimizer._get_processor(v)).ToList(); var var_refs = processors.Select(x => x.target()).ToArray(); diff --git a/src/TensorFlowNET.Core/Util/UnmanagedExtensions.cs b/src/TensorFlowNET.Core/Util/UnmanagedExtensions.cs new file mode 100644 index 00000000..02b8bb73 --- /dev/null +++ b/src/TensorFlowNET.Core/Util/UnmanagedExtensions.cs @@ -0,0 +1,94 @@ +using System; +using System.IO; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using NumSharp.Backends.Unmanaged; + +namespace Tensorflow.Util +{ + public static class UnmanagedExtensions + { + //internally UnmanagedMemoryStream can't construct with null address. + private static readonly unsafe byte* _empty = (byte*) Marshal.AllocHGlobal(1); + + /// + /// Creates a memory stream based on given . + /// + /// The block to stream. Can be default/null. + /// There is no need to dispose the returned + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static UnmanagedMemoryStream Stream(this UnmanagedMemoryBlock block) + { + unsafe + { + if (block.Address == null) + return new UnmanagedMemoryStream(_empty, 0); + return new UnmanagedMemoryStream(block.Address, block.BytesCount); + } + } + + /// + /// Creates a memory stream based on given . + /// + /// The block to stream. Can be default/null. + /// Offset from the start of the block. + /// There is no need to dispose the returned + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static UnmanagedMemoryStream Stream(this UnmanagedMemoryBlock block, long offset) + { + if (block.BytesCount - offset <= 0) + throw new ArgumentOutOfRangeException(nameof(offset)); + + unsafe + { + if (block.Address == null) + return new UnmanagedMemoryStream(_empty, 0); + return new UnmanagedMemoryStream(block.Address + offset, block.BytesCount - offset); + } + } + + /// + /// Creates a memory stream based on given . + /// + /// The block to stream. Can be IntPtr.Zero. + /// The length of the block in bytes. 
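/// A hypothetical use with a tensor's raw buffer (names assumed): tensor.buffer.Stream((long) tensor.bytesize).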
+ /// There is no need to dispose the returned + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static UnmanagedMemoryStream Stream(this IntPtr address, long length) + { + if (length <= 0) + throw new ArgumentOutOfRangeException(nameof(length)); + + unsafe + { + if (address == IntPtr.Zero) + return new UnmanagedMemoryStream(_empty, 0); + + // ReSharper disable once AssignNullToNotNullAttribute + return new UnmanagedMemoryStream((byte*) address, length); + } + } + + /// + /// Creates a memory stream based on given . + /// + /// The block to stream. Can be IntPtr.Zero. + /// Offset from the start of the block. + /// The length of the block in bytes. + /// There is no need to dispose the returned + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static UnmanagedMemoryStream Stream(this IntPtr address, long offset, long length) + { + if (length <= 0) + throw new ArgumentOutOfRangeException(nameof(length)); + + unsafe + { + if (address == IntPtr.Zero) + return new UnmanagedMemoryStream(_empty, 0); + + return new UnmanagedMemoryStream((byte*) address + offset, length); + } + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 9ac7e6ea..e0e3e0f7 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -121,7 +121,7 @@ namespace Tensorflow if(collections == null) { - collections = new List { ops.GraphKeys.GLOBAL_VARIABLES }; + collections = new List { tf.GraphKeys.GLOBAL_VARIABLES }; } // Store the graph key so optimizers know how to only retrieve variables from @@ -129,8 +129,8 @@ namespace Tensorflow _graph_key = ops.get_default_graph().graph_key; _trainable = trainable; - if (trainable && !collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES)) - collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES); + if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) + collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); ops.init_scope(); var values = init_from_fn ? new object[0] : new object[] { initial_value }; @@ -158,7 +158,7 @@ namespace Tensorflow // Or get the initial value from a Tensor or Python object. else { - _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); + _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype); var shape = _initial_value.shape; dtype = _initial_value.dtype; @@ -308,5 +308,28 @@ namespace Tensorflow { throw new NotImplementedException(); } + + /// + /// Returns the value of this variable, read in the current context. + /// + /// + private ITensorOrOperation read_value() + { + return array_ops.identity(_variable, name: "read"); + } + + public Tensor is_variable_initialized(RefVariable variable) + { + return state_ops.is_variable_initialized(variable); + } + + public Tensor initialized_value() + { + ops.init_scope(); + throw new NotImplementedException(""); + /*return control_flow_ops.cond(is_variable_initialized(this), + read_value, + () => initial_value);*/ + } } } diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs index af34a2ba..5c8744b6 100644 --- a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs +++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs @@ -14,6 +14,7 @@ limitations under the License. 
******************************************************************************/ +using System; using System.Collections.Generic; using Tensorflow.Eager; @@ -145,5 +146,10 @@ namespace Tensorflow var _op = _op_def_lib._apply_op_helper("ScatterAdd", name: name, args: new { @ref, indices, updates, use_locking }); return _op.outputs[0]; } + + public static Tensor is_variable_initialized(RefVariable @ref, string name = null) + { + throw new NotImplementedException(""); + } } } diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs index 502c3c1e..8f478f2d 100644 --- a/src/TensorFlowNET.Core/Variables/state_ops.cs +++ b/src/TensorFlowNET.Core/Variables/state_ops.cs @@ -106,5 +106,13 @@ namespace Tensorflow throw new NotImplementedException("scatter_add"); } + + public static Tensor is_variable_initialized(RefVariable @ref, string name = null) + { + if (@ref.dtype.is_ref_dtype()) + return gen_state_ops.is_variable_initialized(@ref: @ref, name: name); + throw new NotImplementedException(""); + //return @ref.is_initialized(name: name); + } } } diff --git a/src/TensorFlowNET.Core/Variables/variables.py.cs b/src/TensorFlowNET.Core/Variables/variables.py.cs index 3880bc7f..6e9d0e4c 100644 --- a/src/TensorFlowNET.Core/Variables/variables.py.cs +++ b/src/TensorFlowNET.Core/Variables/variables.py.cs @@ -17,6 +17,7 @@ using System; using System.Collections.Generic; using System.Linq; +using static Tensorflow.Binding; namespace Tensorflow { @@ -28,7 +29,7 @@ namespace Tensorflow /// public static object trainable_variables() { - return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES); + return ops.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES); } /// @@ -40,11 +41,11 @@ namespace Tensorflow { var all = new List(); - var collection = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope); + var collection = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope); if(collection != null) all.AddRange(collection as List); - collection = ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope); + collection = ops.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS, scope); if (collection != null) all.AddRange(collection as List); @@ -64,7 +65,7 @@ namespace Tensorflow /// A list of `Variable` objects. public static List global_variables(string scope = null) { - var result = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope); + var result = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope); return result == null ? 
new List() : result as List; } diff --git a/src/TensorFlowNET.Core/globals.regen b/src/TensorFlowNET.Core/globals.regen new file mode 100644 index 00000000..86cbee67 --- /dev/null +++ b/src/TensorFlowNET.Core/globals.regen @@ -0,0 +1,40 @@ +%all_dtypes = ["NDArray","Complex","Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single","String"] +%all_dtypes_lowercase = ["NDArray","Complex","bool","byte","short","ushort","int","uint","long","ulong","char","double","float","string"] + +%supported_primitives = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single","String"] +%supported_primitives_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float","string"] + +%supported_numericals = ["Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"] +%supported_numericals_lowercase = ["byte","short","ushort","int","uint","long","ulong","char","double","float"] +%supported_numericals_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] +%supported_numericals_onevales = ["1","1","1","1","1u","1L","1UL",1,"1d","1f"] +%supported_numericals_TF_DataType = ["TF_UINT8","TF_INT16","TF_UINT16","TF_INT32","TF_UINT32","TF_INT64","TF_UINT64","TF_STRING","TF_DOUBLE","TF_FLOAT"] +%supported_numericals_TF_DataType_full = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_STRING","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] + +//this is the type we use in summerizing/reducting: +%supported_numericals_accumulatingType = ["UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"] +%supported_numericals_accumulatingType_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] + +%supported_numericals_signed = ["Int16","Int32","Int64","Double","Single"] +%supported_numericals_signed_lowercase = ["short","int","long","double","float"] +%supported_numericals_signed_defaultvals = ["0","0","0L","0d","0f"] +%supported_numericals_signed_onevales = ["1","1","1L","1d","1f"] + +%supported_numericals_unsigned = ["Byte","UInt16","UInt32","UInt64","Char"] +%supported_numericals_unsigned_lowercase = ["byte","ushort","uint","ulong","char"] +%supported_numericals_unsigned_defaultvals = ["0","0","0U","0UL","'\0'"] +%supported_numericals_unsigned_onevales = ["1","1","1U","1UL","'\1'"] + +%supported_dtypes = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"] +%supported_dtypes_TF_DataType = ["TF_BOOL","TF_UINT8","TF_INT16","TF_UINT16","TF_INT32","TF_UINT32","TF_INT64","TF_UINT64","TF_STRING","TF_DOUBLE","TF_FLOAT"] +%supported_dtypes_TF_DataType_full = ["TF_DataType.TF_BOOL","TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_STRING","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] + +%supported_dtypes_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float"] +%supported_dtypes_defaultvals = [false,"0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] +%supported_dtypes_onevales = [true,"1","1","1","1","1u","1L","1UL","'\1'","1d","1f"] +%supported_dtypes_dtype = ["bool","uint8","int16","uint16","int32","uint32","int64","uint64","uint8","float64","float32"] + +//this is the type we use in summerizing/reducting: 
+%supported_dtypes_accumulatingType = ["Int32","UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"] +%supported_dtypes_accumulatingType_defaultvals = [false, "0","0","0","0u","0L","0UL","'\0'","0d","0f"] + diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs index 94e1b8d5..c5a06433 100644 --- a/src/TensorFlowNET.Core/ops.GraphKeys.cs +++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs @@ -27,57 +27,113 @@ namespace Tensorflow /// specified, but it is also possible to pass an explicit list of /// variables. /// - public static class GraphKeys + public class GraphKeys { + #region const + + + /// + /// the subset of `Variable` objects that will be trained by an optimizer. + /// + public const string TRAINABLE_VARIABLES_ = "trainable_variables"; + + /// + /// Trainable resource-style variables. + /// + public const string TRAINABLE_RESOURCE_VARIABLES_ = "trainable_resource_variables"; + + /// + /// Key for streaming model ports. + /// + public const string _STREAMING_MODEL_PORTS_ = "streaming_model_ports"; + + /// + /// Key to collect losses + /// + public const string LOSSES_ = "losses"; + + /// + /// Key to collect Variable objects that are global (shared across machines). + /// Default collection for all variables, except local ones. + /// + public const string GLOBAL_VARIABLES_ = "variables"; + + public const string TRAIN_OP_ = "train_op"; + + public const string GLOBAL_STEP_ = "global_step"; + + public string[] _VARIABLE_COLLECTIONS_ = new string[] { "variables", "trainable_variables", "model_variables" }; + /// + /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. + /// + public const string SAVEABLE_OBJECTS_ = "saveable_objects"; + /// + /// Key to collect update_ops + /// + public const string UPDATE_OPS_ = "update_ops"; + + // Key to collect summaries. + public const string SUMMARIES_ = "summaries"; + + // Used to store v2 summary names. + public const string _SUMMARY_COLLECTION_ = "_SUMMARY_V2"; + + // Key for control flow context. + public const string COND_CONTEXT_ = "cond_context"; + public const string WHILE_CONTEXT_ = "while_context"; + + #endregion + + /// /// the subset of `Variable` objects that will be trained by an optimizer. /// - public static string TRAINABLE_VARIABLES = "trainable_variables"; + public string TRAINABLE_VARIABLES => TRAINABLE_VARIABLES_; /// /// Trainable resource-style variables. /// - public static string TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"; + public string TRAINABLE_RESOURCE_VARIABLES => TRAINABLE_RESOURCE_VARIABLES_; /// /// Key for streaming model ports. /// - public static string _STREAMING_MODEL_PORTS = "streaming_model_ports"; + public string _STREAMING_MODEL_PORTS => _STREAMING_MODEL_PORTS_; /// /// Key to collect losses /// - public const string LOSSES = "losses"; + public string LOSSES => LOSSES_; /// /// Key to collect Variable objects that are global (shared across machines). /// Default collection for all variables, except local ones. 
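/// With GraphKeys now exposed as an instance, call sites read e.g.
/// var globals = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES); (as elsewhere in this diff).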
/// - public static string GLOBAL_VARIABLES = "variables"; + public string GLOBAL_VARIABLES => GLOBAL_VARIABLES_; - public static string TRAIN_OP = "train_op"; + public string TRAIN_OP => TRAIN_OP_; - public static string GLOBAL_STEP = GLOBAL_STEP = "global_step"; + public string GLOBAL_STEP => GLOBAL_STEP_; - public static string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables", "model_variables" }; + public string[] _VARIABLE_COLLECTIONS => _VARIABLE_COLLECTIONS_; /// /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. /// - public static string SAVEABLE_OBJECTS = "saveable_objects"; + public string SAVEABLE_OBJECTS => SAVEABLE_OBJECTS_; /// /// Key to collect update_ops /// - public static string UPDATE_OPS = "update_ops"; + public string UPDATE_OPS => UPDATE_OPS_; // Key to collect summaries. - public const string SUMMARIES = "summaries"; + public string SUMMARIES => SUMMARIES_; // Used to store v2 summary names. - public static string _SUMMARY_COLLECTION = "_SUMMARY_V2"; + public string _SUMMARY_COLLECTION => _SUMMARY_COLLECTION_; // Key for control flow context. - public static string COND_CONTEXT = "cond_context"; - public static string WHILE_CONTEXT = "while_context"; + public string COND_CONTEXT => COND_CONTEXT_; + public string WHILE_CONTEXT => WHILE_CONTEXT_; } } } diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index b2945f13..1dc8eb56 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -230,8 +230,8 @@ namespace Tensorflow // Add attrs foreach (var attr in node_def.Attr) { - var bytes = attr.Value.ToByteArray(); - var proto = Marshal.AllocHGlobal(bytes.Length); + var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. 
+ var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak Marshal.Copy(bytes, 0, proto, bytes.Length); uint len = (uint)bytes.Length; c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); @@ -488,6 +488,8 @@ namespace Tensorflow switch (value) { + case String str: + return constant_op.constant(str, dtype: TF_DataType.TF_STRING, name: name); case NDArray nd: return constant_op.constant(nd, dtype: dtype, name: name); case Tensor tensor: diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index da873722..ca903844 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -64,13 +64,12 @@ namespace Tensorflow public Session Session() { - defaultSession = new Session(); - return defaultSession; + return new Session(); } - public Session Session(Graph graph) + public Session Session(Graph graph, SessionOptions opts = null) { - return new Session(graph); + return new Session(graph, opts: opts); } public Session Session(SessionOptions opts) diff --git a/src/TensorFlowNet.Benchmarks/Program.cs b/src/TensorFlowNet.Benchmarks/Program.cs index e17a1d68..ea7c2bde 100644 --- a/src/TensorFlowNet.Benchmarks/Program.cs +++ b/src/TensorFlowNet.Benchmarks/Program.cs @@ -9,24 +9,18 @@ namespace TensorFlowBenchmark { static void Main(string[] args) { -#if DEBUG - IConfig config = new DebugInProcessConfig(); -#else - IConfig config = null; -#endif - if (args?.Length > 0) { for (int i = 0; i < args.Length; i++) { string name = $"TensorFlowBenchmark.{args[i]}"; var type = Type.GetType(name); - BenchmarkRunner.Run(type, config); + BenchmarkRunner.Run(type); } } else { - BenchmarkSwitcher.FromAssembly(Assembly.GetExecutingAssembly()).Run(args, config); + BenchmarkSwitcher.FromAssembly(Assembly.GetExecutingAssembly()).Run(args, ManualConfig.Create(DefaultConfig.Instance).With(ConfigOptions.DisableOptimizationsValidator)); } Console.ReadLine(); diff --git a/src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj b/src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj index bc2a0ff3..4618f06b 100644 --- a/src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj +++ b/src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj @@ -6,6 +6,7 @@ true TensorFlowBenchmark TensorFlowBenchmark + 7.3 diff --git a/src/TensorFlowNet.Benchmarks/Unmanaged/StructCastBenchmark.cs b/src/TensorFlowNet.Benchmarks/Unmanaged/StructCastBenchmark.cs new file mode 100644 index 00000000..5b3a0cd3 --- /dev/null +++ b/src/TensorFlowNet.Benchmarks/Unmanaged/StructCastBenchmark.cs @@ -0,0 +1,76 @@ +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using BenchmarkDotNet.Attributes; +using Google.Protobuf.WellKnownTypes; +using NumSharp; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowBenchmark.Unmanaged +{ + public struct UnmanagedStruct + { + public int a; + public long b; + public UnmanagedStruct(int _) + { + a = 2; + b = 3; + } + } + + [SimpleJob(launchCount: 1, warmupCount: 2, targetCount: 10)] + [MinColumn, MaxColumn, MeanColumn, MedianColumn] + public unsafe class StructCastBenchmark + { + private static void EnsureIsUnmanaged(T _) where T : unmanaged + { } + + static StructCastBenchmark() //if UnmanagedStruct is not unmanaged struct then this will fail to compile. 
+            => EnsureIsUnmanaged(new UnmanagedStruct());
+
+        private IntPtr data;
+        private void* dataptr;
+
+        [GlobalSetup]
+        public void Setup()
+        {
+            data = Marshal.AllocHGlobal(Marshal.SizeOf<UnmanagedStruct>());
+            dataptr = data.ToPointer();
+        }
+
+        [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)]
+        public void Marshal_PtrToStructure()
+        {
+            UnmanagedStruct _;
+            for (int i = 0; i < 10000; i++)
+            {
+                _ = Marshal.PtrToStructure<UnmanagedStruct>(data);
+            }
+        }
+
+        [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)]
+        public void PointerCast()
+        {
+            var dptr = dataptr;
+            UnmanagedStruct _;
+            for (int i = 0; i < 10000; i++)
+            {
+                _ = *(UnmanagedStruct*) dptr;
+            }
+        }
+
+        [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)]
+        public void Unsafe_Read()
+        {
+            var dptr = dataptr;
+            UnmanagedStruct _;
+            for (int i = 0; i < 10000; i++)
+            {
+                _ = Unsafe.Read<UnmanagedStruct>(dptr);
+            }
+        }
+
+    }
+}
\ No newline at end of file
diff --git a/test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs b/test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs
index a9dfbe7e..9b33b28f 100644
--- a/test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs
+++ b/test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs
@@ -80,26 +80,18 @@ namespace TensorFlowNET.Examples
                 for (int epoch = 0; epoch < training_epochs; epoch++)
                 {
                     foreach (var (x, y) in zip(train_X, train_Y))
-                    {
-                        sess.run(optimizer,
-                            new FeedItem(X, x),
-                            new FeedItem(Y, y));
-                    }
+                        sess.run(optimizer, (X, x), (Y, y));
 
                     // Display logs per epoch step
                     if ((epoch + 1) % display_step == 0)
                     {
-                        var c = sess.run(cost,
-                            new FeedItem(X, train_X),
-                            new FeedItem(Y, train_Y));
+                        var c = sess.run(cost, (X, train_X), (Y, train_Y));
                         Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}");
                     }
                 }
 
                 Console.WriteLine("Optimization Finished!");
-                var training_cost = sess.run(cost,
-                    new FeedItem(X, train_X),
-                    new FeedItem(Y, train_Y));
+                var training_cost = sess.run(cost, (X, train_X), (Y, train_Y));
                 Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}");
 
                 // Testing example
@@ -107,8 +99,7 @@ namespace TensorFlowNET.Examples
                 var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f);
                 Console.WriteLine("Testing... 
(Mean square loss Comparison)"); var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), - new FeedItem(X, test_X), - new FeedItem(Y, test_Y)); + (X, test_X), (Y, test_Y)); Console.WriteLine($"Testing cost={testing_cost}"); var diff = Math.Abs((float)training_cost - (float)testing_cost); Console.WriteLine($"Absolute mean square loss difference: {diff}"); diff --git a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs index 73d40d28..3116e6f4 100644 --- a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs +++ b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs @@ -102,7 +102,7 @@ namespace TensorFlowNET.Examples // Display logs per epoch step if ((epoch + 1) % display_step == 0) - print($"Epoch: {(epoch + 1).ToString("D4")} Cost: {avg_cost.ToString("G9")} Elapse: {sw.ElapsedMilliseconds}ms"); + print($"Epoch: {(epoch + 1):D4} Cost: {avg_cost:G9} Elapse: {sw.ElapsedMilliseconds}ms"); sw.Reset(); } @@ -114,8 +114,8 @@ namespace TensorFlowNET.Examples var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); // Calculate accuracy var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); - float acc = accuracy.eval((x, mnist.Test.Data), (y, mnist.Test.Labels)); - print($"Accuracy: {acc.ToString("F4")}"); + float acc = accuracy.eval(sess, (x, mnist.Test.Data), (y, mnist.Test.Labels)); + print($"Accuracy: {acc:F4}"); return acc > 0.9; } diff --git a/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection.cs b/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection.cs index 50093f3c..d0c06704 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection.cs @@ -84,7 +84,7 @@ namespace TensorFlowNET.Examples public void PrepareData() { // get model file - string url = "http://download.tf.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz"; + string url = "http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz"; Web.Download(url, modelDir, "ssd_mobilenet_v1_coco.tar.gz"); Compress.ExtractTGZ(Path.Join(modelDir, "ssd_mobilenet_v1_coco.tar.gz"), "./"); diff --git a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs index 79cc548f..7f2d81f4 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs @@ -21,6 +21,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.Linq; +using System.Threading.Tasks; using Tensorflow; using TensorFlowNET.Examples.Utility; using static Tensorflow.Binding; @@ -381,10 +382,15 @@ namespace TensorFlowNET.Examples Tensor resized_input_tensor, Tensor bottleneck_tensor, string module_name) { int how_many_bottlenecks = 0; - foreach (var (label_name, label_lists) in image_lists) + var kvs = image_lists.ToArray(); + var categories = new string[] {"training", "testing", "validation"}; + Parallel.For(0, kvs.Length, i => { - foreach (var category in new string[] { "training", "testing", "validation" }) + var (label_name, label_lists) = kvs[i]; + + Parallel.For(0, categories.Length, j => { + var category = categories[j]; var category_list = label_lists[category]; foreach (var (index, unused_base_name) in enumerate(category_list)) { @@ -395,8 +401,8 @@ 
namespace TensorFlowNET.Examples if (how_many_bottlenecks % 300 == 0) print($"{how_many_bottlenecks} bottleneck files created."); } - } - } + }); + }); } private float[] get_or_create_bottleneck(Session sess, Dictionary> image_lists, @@ -508,7 +514,7 @@ namespace TensorFlowNET.Examples { // get a set of images to teach the network about the new classes string fileName = "flower_photos.tgz"; - string url = $"http://download.tf.org/example_images/{fileName}"; + string url = $"http://download.tensorflow.org/example_images/{fileName}"; Web.Download(url, data_dir, fileName); Compress.ExtractTGZ(Path.Join(data_dir, fileName), data_dir); diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Dataset.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Dataset.cs new file mode 100644 index 00000000..482280ca --- /dev/null +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Dataset.cs @@ -0,0 +1,54 @@ +using NumSharp; +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using static Tensorflow.Binding; + +namespace TensorFlowNET.Examples.ImageProcessing.YOLO +{ + public class Dataset + { + string annot_path; + int[] input_sizes; + int batch_size; + bool data_aug; + int[] train_input_sizes; + NDArray strides; + NDArray anchors; + Dictionary classes; + int num_classes; + int anchor_per_scale; + int max_bbox_per_scale; + string[] annotations; + int num_samples; + int batch_count; + + public int Length = 0; + + public Dataset(string dataset_type, Config cfg) + { + annot_path = dataset_type == "train" ? cfg.TRAIN.ANNOT_PATH : cfg.TEST.ANNOT_PATH; + input_sizes = dataset_type == "train" ? cfg.TRAIN.INPUT_SIZE : cfg.TEST.INPUT_SIZE; + batch_size = dataset_type == "train" ? cfg.TRAIN.BATCH_SIZE : cfg.TEST.BATCH_SIZE; + data_aug = dataset_type == "train" ? 
cfg.TRAIN.DATA_AUG : cfg.TEST.DATA_AUG; + train_input_sizes = cfg.TRAIN.INPUT_SIZE; + strides = np.array(cfg.YOLO.STRIDES); + + classes = Utils.read_class_names(cfg.YOLO.CLASSES); + num_classes = classes.Count; + anchors = np.array(Utils.get_anchors(cfg.YOLO.ANCHORS)); + anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE; + max_bbox_per_scale = 150; + + annotations = load_annotations(); + num_samples = len(annotations); + batch_count = 0; + } + + string[] load_annotations() + { + return File.ReadAllLines(annot_path); + } + } +} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs new file mode 100644 index 00000000..c3201f8c --- /dev/null +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs @@ -0,0 +1,139 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.Examples.ImageProcessing.YOLO +{ + /// + /// Implementation of YOLO v3 object detector in Tensorflow + /// https://github.com/YunYang1994/tensorflow-yolov3 + /// + public class Main : IExample + { + public bool Enabled { get; set; } = true; + public bool IsImportingGraph { get; set; } = false; + public string Name => "YOLOv3"; + + #region args + Dictionary classes; + int num_classes; + float learn_rate_init; + float learn_rate_end; + int first_stage_epochs; + int second_stage_epochs; + int warmup_periods; + string time; + float moving_ave_decay; + int max_bbox_per_scale; + int steps_per_period; + + Dataset trainset, testset; + + Config cfg; + + Tensor input_data; + Tensor label_sbbox; + Tensor label_mbbox; + Tensor label_lbbox; + Tensor true_sbboxes; + Tensor true_mbboxes; + Tensor true_lbboxes; + Tensor trainable; + + Session sess; + YOLOv3 model; + #endregion + + public bool Run() + { + PrepareData(); + + var graph = IsImportingGraph ? 
ImportGraph() : BuildGraph(); + + var options = new SessionOptions(); + options.SetConfig(new ConfigProto { AllowSoftPlacement = true }); + using (var sess = tf.Session(graph, opts: options)) + { + Train(sess); + } + + return true; + } + + public void Train(Session sess) + { + + } + + public void Test(Session sess) + { + throw new NotImplementedException(); + } + + public Graph BuildGraph() + { + var graph = new Graph().as_default(); + + tf_with(tf.name_scope("define_input"), scope => + { + input_data = tf.placeholder(dtype: tf.float32, name: "input_data"); + label_sbbox = tf.placeholder(dtype: tf.float32, name: "label_sbbox"); + label_mbbox = tf.placeholder(dtype: tf.float32, name: "label_mbbox"); + label_lbbox = tf.placeholder(dtype: tf.float32, name: "label_lbbox"); + true_sbboxes = tf.placeholder(dtype: tf.float32, name: "sbboxes"); + true_mbboxes = tf.placeholder(dtype: tf.float32, name: "mbboxes"); + true_lbboxes = tf.placeholder(dtype: tf.float32, name: "lbboxes"); + trainable = tf.placeholder(dtype: tf.@bool, name: "training"); + }); + + tf_with(tf.name_scope("define_loss"), scope => + { + model = new YOLOv3(cfg, input_data, trainable); + }); + + tf_with(tf.name_scope("define_weight_decay"), scope => + { + var moving_ave = tf.train.ExponentialMovingAverage(moving_ave_decay).apply((RefVariable[])tf.trainable_variables()); + }); + + return graph; + } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public void Predict(Session sess) + { + throw new NotImplementedException(); + } + + public void PrepareData() + { + cfg = new Config(Name); + + string dataDir = Path.Combine(Name, "data"); + Directory.CreateDirectory(dataDir); + + classes = Utils.read_class_names(cfg.YOLO.CLASSES); + num_classes = classes.Count; + + learn_rate_init = cfg.TRAIN.LEARN_RATE_INIT; + learn_rate_end = cfg.TRAIN.LEARN_RATE_END; + first_stage_epochs = cfg.TRAIN.FISRT_STAGE_EPOCHS; + second_stage_epochs = cfg.TRAIN.SECOND_STAGE_EPOCHS; + warmup_periods = cfg.TRAIN.WARMUP_EPOCHS; + DateTime now = DateTime.Now; + time = $"{now.Year}-{now.Month}-{now.Day}-{now.Hour}-{now.Minute}-{now.Minute}"; + moving_ave_decay = cfg.YOLO.MOVING_AVE_DECAY; + max_bbox_per_scale = 150; + trainset = new Dataset("train", cfg); + testset = new Dataset("test", cfg); + steps_per_period = trainset.Length; + } + } +} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Utils.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Utils.cs new file mode 100644 index 00000000..3a0d3089 --- /dev/null +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Utils.cs @@ -0,0 +1,27 @@ +using NumSharp; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; + +namespace TensorFlowNET.Examples.ImageProcessing.YOLO +{ + class Utils + { + public static Dictionary read_class_names(string file) + { + var classes = new Dictionary(); + foreach (var line in File.ReadAllLines(file)) + classes[classes.Count] = line; + return classes; + } + + public static NDArray get_anchors(string file) + { + return np.array(File.ReadAllText(file).Split(',') + .Select(x => float.Parse(x)) + .ToArray()).reshape(3, 3, 2); + } + } +} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs new file mode 100644 index 00000000..de5f0acc --- /dev/null +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs @@ -0,0 +1,65 @@ +using NumSharp; +using System; +using System.Collections.Generic; +using System.Text; 
+using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.Examples.ImageProcessing.YOLO +{ + public class YOLOv3 + { + Config cfg; + Tensor trainable; + Tensor input_data; + Dictionary classes; + int num_class; + NDArray strides; + NDArray anchors; + int anchor_per_scale; + float iou_loss_thresh; + string upsample_method; + Tensor conv_lbbox; + Tensor conv_mbbox; + Tensor conv_sbbox; + + public YOLOv3(Config cfg_, Tensor input_data_, Tensor trainable_) + { + cfg = cfg_; + input_data = input_data_; + trainable = trainable_; + classes = Utils.read_class_names(cfg.YOLO.CLASSES); + num_class = len(classes); + strides = np.array(cfg.YOLO.STRIDES); + anchors = Utils.get_anchors(cfg.YOLO.ANCHORS); + anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE; + iou_loss_thresh = cfg.YOLO.IOU_LOSS_THRESH; + upsample_method = cfg.YOLO.UPSAMPLE_METHOD; + + (conv_lbbox, conv_mbbox, conv_sbbox) = __build_nework(input_data); + + tf_with(tf.variable_scope("pred_sbbox"), scope => + { + // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); + }); + + tf_with(tf.variable_scope("pred_mbbox"), scope => + { + // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); + }); + + tf_with(tf.variable_scope("pred_lbbox"), scope => + { + // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); + }); + } + + private (Tensor, Tensor, Tensor) __build_nework(Tensor input_data) + { + Tensor route_1, route_2; + (route_1, route_2, input_data) = backbone.darknet53(input_data, trainable); + + return (conv_lbbox, conv_mbbox, conv_sbbox); + } + } +} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/backbone.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/backbone.cs new file mode 100644 index 00000000..0e7b1446 --- /dev/null +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/backbone.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.Examples.ImageProcessing.YOLO +{ + class backbone + { + public static (Tensor, Tensor, Tensor) darknet53(Tensor input_data, Tensor trainable) + { + return tf_with(tf.variable_scope("darknet"), scope => + { + input_data = common.convolutional(input_data, filters_shape: new int[] { 3, 3, 3, 32 }, trainable: trainable, name: "conv0"); + input_data = common.convolutional(input_data, filters_shape: new int[] { 3, 3, 32, 64 }, trainable: trainable, name: "conv1", downsample: true); + + foreach (var i in range(1)) + input_data = common.residual_block(input_data, 64, 32, 64, trainable: trainable, name: $"residual{i + 0}"); + + var route_1 = input_data; + var route_2 = input_data; + + return (route_1, route_2, input_data); + }); + } + } +} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs new file mode 100644 index 00000000..57105aa1 --- /dev/null +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs @@ -0,0 +1,72 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.Examples.ImageProcessing.YOLO +{ + class common + { + public static Tensor convolutional(Tensor input_data, int[] filters_shape, Tensor trainable, + string name, bool downsample = false, bool activate = true, + bool bn = true) + { + return tf_with(tf.variable_scope(name), scope => + { + int[] strides; + string padding; + + if (downsample) + { + throw new NotImplementedException(""); + } + else + { + 
strides = new int[] { 1, 1, 1, 1 }; + padding = "SAME"; + } + + var weight = tf.get_variable(name: "weight", dtype: tf.float32, trainable: true, + shape: filters_shape, initializer: tf.random_normal_initializer(stddev: 0.01f)); + + var conv = tf.nn.conv2d(input: input_data, filter: weight, strides: strides, padding: padding); + + if (bn) + { + conv = tf.layers.batch_normalization(conv, beta_initializer: tf.zeros_initializer, + gamma_initializer: tf.ones_initializer, + moving_mean_initializer: tf.zeros_initializer, + moving_variance_initializer: tf.ones_initializer, training: trainable); + } + else + { + throw new NotImplementedException(""); + } + + if (activate) + conv = tf.nn.leaky_relu(conv, alpha: 0.1f); + + return conv; + }); + } + + public static Tensor residual_block(Tensor input_data, int input_channel, int filter_num1, + int filter_num2, Tensor trainable, string name) + { + var short_cut = input_data; + + return tf_with(tf.variable_scope(name), scope => + { + input_data = convolutional(input_data, filters_shape: new int[] { 1, 1, input_channel, filter_num1 }, + trainable: trainable, name: "conv1"); + input_data = convolutional(input_data, filters_shape: new int[] { 3, 3, filter_num1, filter_num2 }, + trainable: trainable, name: "conv2"); + + var residual_output = input_data + short_cut; + + return residual_output; + }); + } + } +} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs new file mode 100644 index 00000000..b5c46151 --- /dev/null +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs @@ -0,0 +1,94 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; + +namespace TensorFlowNET.Examples.ImageProcessing.YOLO +{ + public class Config + { + public YoloConfig YOLO; + public TrainConfig TRAIN; + public TestConfig TEST; + + public Config(string root) + { + YOLO = new YoloConfig(root); + TRAIN = new TrainConfig(root); + TEST = new TestConfig(root); + } + + public class YoloConfig + { + string _root; + + public string CLASSES; + public string ANCHORS; + public float MOVING_AVE_DECAY = 0.9995f; + public int[] STRIDES = new int[] { 8, 16, 32 }; + public int ANCHOR_PER_SCALE = 3; + public float IOU_LOSS_THRESH = 0.5f; + public string UPSAMPLE_METHOD = "resize"; + public string ORIGINAL_WEIGHT; + public string DEMO_WEIGHT; + + public YoloConfig(string root) + { + _root = root; + CLASSES = Path.Combine(_root, "data", "classes", "coco.names"); + ANCHORS = Path.Combine(_root, "data", "anchors", "basline_anchors.txt"); + ORIGINAL_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco.ckpt"); + DEMO_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco_demo.ckpt"); + } + } + + public class TrainConfig + { + string _root; + + public int BATCH_SIZE = 6; + public int[] INPUT_SIZE = new int[] { 320, 352, 384, 416, 448, 480, 512, 544, 576, 608 }; + public bool DATA_AUG = true; + public float LEARN_RATE_INIT = 1e-4f; + public float LEARN_RATE_END = 1e-6f; + public int WARMUP_EPOCHS = 2; + public int FISRT_STAGE_EPOCHS = 20; + public int SECOND_STAGE_EPOCHS = 30; + public string INITIAL_WEIGHT; + public string ANNOT_PATH; + + public TrainConfig(string root) + { + _root = root; + INITIAL_WEIGHT = Path.Combine(_root, "data", "checkpoint", "yolov3_coco_demo.ckpt"); + ANNOT_PATH = Path.Combine(_root, "data", "dataset", "voc_train.txt"); + } + } + + public class TestConfig + { + string _root; + + public int BATCH_SIZE = 2; + public int[] INPUT_SIZE = new int[] { 544 }; + 
public bool DATA_AUG = false; + public bool WRITE_IMAGE = true; + public string WRITE_IMAGE_PATH; + public string WEIGHT_FILE; + public bool WRITE_IMAGE_SHOW_LABEL = true; + public bool SHOW_LABEL = true; + public int SECOND_STAGE_EPOCHS = 30; + public float SCORE_THRESHOLD = 0.3f; + public float IOU_THRESHOLD = 0.45f; + public string ANNOT_PATH; + + public TestConfig(string root) + { + _root = root; + ANNOT_PATH = Path.Combine(_root, "data", "dataset", "voc_test.txt"); + WRITE_IMAGE_PATH = Path.Combine(_root, "data", "detection"); + WEIGHT_FILE = Path.Combine(_root, "checkpoint", "yolov3_test_loss=9.2099.ckpt-5"); + } + } + } +} diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj index 1bd3d530..55e9b27d 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj @@ -6,6 +6,14 @@ false + + bin\debug-gpu + + + + bin\release-gpu + + diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index f4e2340a..c675bedc 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -6,6 +6,10 @@ false + + DEBUG;TRACE + + diff --git a/test/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs b/test/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs index 9b28fdc0..6150fa90 100644 --- a/test/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs +++ b/test/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs @@ -118,7 +118,7 @@ namespace TensorFlowNET.Examples.Text var y_one_hot = tf.one_hot(y, num_class); loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits: logits, labels: y_one_hot)); - var update_ops = tf.get_collection(ops.GraphKeys.UPDATE_OPS) as List; + var update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) as List; tf_with(tf.control_dependencies(update_ops.Select(x => (Operation)x).ToArray()), delegate { var adam = tf.train.AdamOptimizer(learning_rate); diff --git a/test/TensorFlowNET.UnitTest/Basics/AssignTests.cs b/test/TensorFlowNET.UnitTest/Basics/AssignTests.cs index 15d9b819..6c593929 100644 --- a/test/TensorFlowNET.UnitTest/Basics/AssignTests.cs +++ b/test/TensorFlowNET.UnitTest/Basics/AssignTests.cs @@ -1,4 +1,5 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.Basics @@ -14,21 +15,22 @@ namespace TensorFlowNET.UnitTest.Basics var expected = new[] { false, true, false, false, true, false, true }; var spike = tf.Variable(false); - - spike.initializer.run(); - foreach (var i in range(1, 2)) + using (var sess = new Session()) { - if (raw_data[i] - raw_data[i - 1] > 5d) - { - var updater = tf.assign(spike, tf.constant(true)); - updater.eval(); - } - else + spike.initializer.run(session: sess); + foreach (var i in range(1, 2)) { - tf.assign(spike, tf.constant(true)).eval(); - } + if (raw_data[i] - raw_data[i - 1] > 5d) + { + var updater = tf.assign(spike, tf.constant(true)); + updater.eval(sess); + } else + { + tf.assign(spike, tf.constant(true)).eval(sess); + } - Assert.AreEqual((bool)spike.eval(), expected[i - 1]); + Assert.AreEqual((bool) spike.eval(), expected[i - 1]); + } } } } diff --git a/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs b/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs index 33e38870..58609c17 100644 --- 
a/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs +++ b/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs @@ -2,6 +2,7 @@ using NumSharp; using System; using Tensorflow; +using Tensorflow.Util; using Buffer = Tensorflow.Buffer; namespace TensorFlowNET.UnitTest @@ -45,15 +46,18 @@ namespace TensorFlowNET.UnitTest private bool GetGraphDef(Graph graph, out GraphDef graph_def) { graph_def = null; - var s = new Status(); - var buffer = new Buffer(); - c_api.TF_GraphToGraphDef(graph, buffer, s); - bool ret = TF_GetCode(s) == TF_OK; - EXPECT_EQ(TF_OK, TF_GetCode(s)); - if (ret) graph_def = GraphDef.Parser.ParseFrom(buffer.Data); - buffer.Dispose(); - s.Dispose(); - return ret; + using (var s = new Status()) + { + using (var buffer = new Buffer()) + { + c_api.TF_GraphToGraphDef(graph, buffer, s); + bool ret = TF_GetCode(s) == TF_OK; + EXPECT_EQ(TF_OK, TF_GetCode(s)); + if (ret) + graph_def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + return ret; + } + } } private void RunGraphsAndCompareOutputs(TF_Output[] grad_outputs, TF_Output[] expected_grad_outputs) diff --git a/test/TensorFlowNET.UnitTest/CSession.cs b/test/TensorFlowNET.UnitTest/CSession.cs index 33e88286..ae57b075 100644 --- a/test/TensorFlowNET.UnitTest/CSession.cs +++ b/test/TensorFlowNET.UnitTest/CSession.cs @@ -40,10 +40,7 @@ namespace TensorFlowNET.UnitTest private void DeleteInputValues() { - for (var i = 0; i < input_values_.Count; ++i) - { - input_values_[i].Dispose(); - } + //clearing is enough as they will be disposed by the GC unless they are referenced else-where. input_values_.Clear(); } @@ -60,11 +57,7 @@ namespace TensorFlowNET.UnitTest private void ResetOutputValues() { - for (var i = 0; i < output_values_.Count; ++i) - { - if (output_values_[i] != IntPtr.Zero) - output_values_[i].Dispose(); - } + //clearing is enough as they will be disposed by the GC unless they are referenced else-where. 
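// ---- Aside (reviewer sketch, not part of this patch) -----------------------------------------
// A minimal illustration of the comment above: once the list drops its references, finalizable
// wrappers can still be cleaned up by the GC, so the Dispose loops removed here mainly trade
// deterministic cleanup for finalizer-time cleanup. `NativeHandle` is a hypothetical stand-in
// for the disposable values being cleared, not a TF.NET type.
using System;
using System.Collections.Generic;

class NativeHandle : IDisposable
{
    public void Dispose()
    {
        Release();
        GC.SuppressFinalize(this);   // deterministic path skips the finalizer
    }

    ~NativeHandle() => Release();    // GC path; timing is not guaranteed

    void Release() { /* native memory would be released here */ }
}

class ValueListSketch
{
    readonly List<NativeHandle> values = new List<NativeHandle>();

    public void Reset()
    {
        // foreach (var v in values) v.Dispose();   // what the removed loop did (immediate)
        values.Clear();                             // what remains: rely on the finalizers
    }
}
// -----------------------------------------------------------------------------------------------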
output_values_.Clear(); } diff --git a/test/TensorFlowNET.UnitTest/ConstantTest.cs b/test/TensorFlowNET.UnitTest/ConstantTest.cs index c1d4c9e5..b532e558 100644 --- a/test/TensorFlowNET.UnitTest/ConstantTest.cs +++ b/test/TensorFlowNET.UnitTest/ConstantTest.cs @@ -98,9 +98,9 @@ namespace TensorFlowNET.UnitTest { var result = sess.run(tensor); - Assert.AreEqual(result[0].shape[0], 3); - Assert.AreEqual(result[0].shape[1], 2); - Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, result[0].Data())); + Assert.AreEqual(result.shape[0], 3); + Assert.AreEqual(result.shape[1], 2); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, result.Data())); } // big size @@ -109,13 +109,13 @@ namespace TensorFlowNET.UnitTest { var result = sess.run(tensor); - Assert.AreEqual(result[0].shape[0], 200); - Assert.AreEqual(result[0].shape[1], 100); + Assert.AreEqual(result.shape[0], 200); + Assert.AreEqual(result.shape[1], 100); - var data = result[0].Data(); + var data = result.Data(); Assert.AreEqual(0, data[0]); Assert.AreEqual(0, data[500]); - Assert.AreEqual(0, data[result[0].size - 1]); + Assert.AreEqual(0, data[result.size - 1]); } } @@ -127,9 +127,9 @@ namespace TensorFlowNET.UnitTest { var result = sess.run(ones); - Assert.AreEqual(result[0].shape[0], 3); - Assert.AreEqual(result[0].shape[1], 2); - Assert.IsTrue(new[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(result[0].Data())); + Assert.AreEqual(result.shape[0], 3); + Assert.AreEqual(result.shape[1], 2); + Assert.IsTrue(new[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(result.Data())); } } @@ -142,9 +142,9 @@ namespace TensorFlowNET.UnitTest { var result = sess.run(halfes); - Assert.AreEqual(result[0].shape[0], 3); - Assert.AreEqual(result[0].shape[1], 2); - Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(result[0].Data())); + Assert.AreEqual(result.shape[0], 3); + Assert.AreEqual(result.shape[1], 2); + Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(result.Data())); } } @@ -161,10 +161,10 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var result = sess.run(tensor); - var data = result[0].Data(); + var data = result.Data(); - Assert.AreEqual(result[0].shape[0], 2); - Assert.AreEqual(result[0].shape[1], 3); + Assert.AreEqual(result.shape[0], 2); + Assert.AreEqual(result.shape[1], 3); Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 1, 2, 1, 3 }, data)); } } @@ -177,7 +177,7 @@ namespace TensorFlowNET.UnitTest var c = a * b; var sess = tf.Session(); - double result = sess.run(c)[0]; + double result = sess.run(c); sess.close(); Assert.AreEqual(6.0, result); diff --git a/test/TensorFlowNET.UnitTest/GradientTest.cs b/test/TensorFlowNET.UnitTest/GradientTest.cs index b52bc1cf..c8e57ba4 100644 --- a/test/TensorFlowNET.UnitTest/GradientTest.cs +++ b/test/TensorFlowNET.UnitTest/GradientTest.cs @@ -41,7 +41,7 @@ namespace TensorFlowNET.UnitTest var grad = tf.gradients(y, x); Assert.AreEqual(grad[0].name, "gradients/AddN:0"); - float r = sess.run(grad[0])[0]; + float r = sess.run(grad[0]); Assert.AreEqual(r, 1.4f); } } @@ -57,7 +57,7 @@ namespace TensorFlowNET.UnitTest var grad = tf.gradients(y, x); Assert.AreEqual(grad[0].name, "gradients/AddN:0"); - float r = sess.run(grad[0])[0]; + float r = sess.run(grad[0]); Assert.AreEqual(r, 14.700001f); }); } @@ -94,7 +94,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session(graph)) { - var r = sess.run(slice)[0]; + var r = sess.run(slice); Assert.IsTrue(Enumerable.SequenceEqual(r.shape, new[] { 2, 1, 2 })); 
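// ---- Aside (reviewer sketch, not part of this patch) -----------------------------------------
// What the `[0]` removals in these test hunks rely on: for a single fetch, sess.run(...) now
// yields a result that is used directly (cast to a scalar or read as an array) instead of being
// indexed first. A minimal sketch, reusing only calls that appear elsewhere in this patch.
using System;
using static Tensorflow.Binding;

class SingleFetchSketch
{
    static void Main()
    {
        var c = tf.constant(4.0f) + tf.constant(5.0f);
        using (var sess = tf.Session())
        {
            var o = sess.run(c);         // no [0] indexing on the result
            Console.WriteLine((float)o); // 9
        }
    }
}
// -----------------------------------------------------------------------------------------------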
Assert.IsTrue(Enumerable.SequenceEqual(r[0].GetData(), new[] { 11, 13 })); diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index f5431e01..443191dd 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -322,7 +322,6 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ(feed2, control_inputs[1]); // Export to a graph def so we can import a graph with control dependencies - graph_def.Dispose(); graph_def = new Buffer(); c_api.TF_GraphToGraphDef(graph, graph_def, s); EXPECT_EQ(TF_Code.TF_OK, s.Code); @@ -346,14 +345,10 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ(feed4, control_inputs[1]); c_api.TF_DeleteImportGraphDefOptions(opts); - c_api.TF_DeleteBuffer(graph_def); // Can add nodes to the imported graph without trouble. c_test_util.Add(feed, scalar, graph, s); ASSERT_EQ(TF_Code.TF_OK, s.Code); - - graph.Dispose(); - s.Dispose(); } /// @@ -416,12 +411,13 @@ namespace TensorFlowNET.UnitTest } + [TestMethod] public void ImportGraphMeta() { var dir = "my-save-dir/"; using (var sess = tf.Session()) { - var new_saver = tf.train.import_meta_graph(dir + "my-model-10000.meta"); + var new_saver = tf.train.import_meta_graph(@"D:\tmp\resnet_v2_101_2017_04_14\eval.graph"); new_saver.restore(sess, dir + "my-model-10000"); var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels"); var batch_size = tf.size(labels); diff --git a/test/TensorFlowNET.UnitTest/ImageTest.cs b/test/TensorFlowNET.UnitTest/ImageTest.cs new file mode 100644 index 00000000..e4f8a835 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ImageTest.cs @@ -0,0 +1,33 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest +{ + /// + /// Find more examples in https://www.programcreek.com/python/example/90444/tensorflow.read_file + /// + [TestClass] + public class ImageTest + { + string imgPath = "../../../../../data/shasta-daisy.jpg"; + Tensor contents; + + public ImageTest() + { + imgPath = Path.GetFullPath(imgPath); + contents = tf.read_file(imgPath); + } + + [TestMethod] + public void decode_image() + { + var img = tf.image.decode_image(contents); + Assert.AreEqual(img.name, "decode_image/cond_jpeg/Merge:0"); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/NameScopeTest.cs b/test/TensorFlowNET.UnitTest/NameScopeTest.cs index 4ff50deb..7a9ae062 100644 --- a/test/TensorFlowNET.UnitTest/NameScopeTest.cs +++ b/test/TensorFlowNET.UnitTest/NameScopeTest.cs @@ -1,4 +1,5 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using Microsoft.VisualStudio.TestTools.UnitTesting; using Tensorflow; using static Tensorflow.Binding; @@ -42,5 +43,37 @@ namespace TensorFlowNET.UnitTest Assert.AreEqual("", g._name_stack); } + + [TestMethod, Ignore("Unimplemented Usage")] + public void NestedNameScope_Using() + { + Graph g = tf.Graph().as_default(); + + using (var name = new ops.NameScope("scope1")) + { + Assert.AreEqual("scope1", g._name_stack); + Assert.AreEqual("scope1/", name); + + var const1 = tf.constant(1.0); + Assert.AreEqual("scope1/Const:0", const1.name); + + using (var name2 = new ops.NameScope("scope2")) + { + Assert.AreEqual("scope1/scope2", g._name_stack); + Assert.AreEqual("scope1/scope2/", name); + + var const2 = tf.constant(2.0); + Assert.AreEqual("scope1/scope2/Const:0", const2.name); + } + + Assert.AreEqual("scope1", 
g._name_stack); + var const3 = tf.constant(2.0); + Assert.AreEqual("scope1/Const_1:0", const3.name); + }; + + g.Dispose(); + + Assert.AreEqual("", g._name_stack); + } } } diff --git a/test/TensorFlowNET.UnitTest/Open.snk b/test/TensorFlowNET.UnitTest/Open.snk new file mode 100644 index 00000000..22a3cbd2 Binary files /dev/null and b/test/TensorFlowNET.UnitTest/Open.snk differ diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 4c6ae3d0..226a4839 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Linq; using NumSharp; using Tensorflow; +using Tensorflow.Util; using Buffer = Tensorflow.Buffer; using static Tensorflow.Binding; @@ -21,7 +22,7 @@ namespace TensorFlowNET.UnitTest { var handle = c_api.TF_GetAllOpList(); var buffer = new Buffer(handle); - var op_list = OpList.Parser.ParseFrom(buffer); + var op_list = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream()); var _registered_ops = new Dictionary(); foreach (var op_def in op_list.Op) @@ -44,7 +45,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, 3.0f), new FeedItem(b, 2.0f)); - Assert.AreEqual((float)o[0], 5.0f); + Assert.AreEqual((float)o, 5.0f); } } @@ -58,7 +59,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var o = sess.run(c); - Assert.AreEqual((float)o[0], 9.0f); + Assert.AreEqual((float)o, 9.0f); } } @@ -72,7 +73,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var o = sess.run(b); - Assert.IsTrue(o[0].array_equal(check)); + Assert.IsTrue(o.array_equal(check)); } } @@ -86,7 +87,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var o = sess.run(b); - Assert.IsTrue(o[0].array_equal(check)); + Assert.IsTrue(o.array_equal(check)); } } @@ -100,7 +101,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var o = sess.run(b); - Assert.IsTrue(o[0].array_equal(check)); + Assert.IsTrue(o.array_equal(check)); } b = tf.cumsum(a, exclusive: true); @@ -109,7 +110,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var o = sess.run(b); - Assert.IsTrue(o[0].array_equal(check)); + Assert.IsTrue(o.array_equal(check)); } b = tf.cumsum(a, reverse: true); @@ -118,7 +119,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var o = sess.run(b); - Assert.IsTrue(o[0].array_equal(check)); + Assert.IsTrue(o.array_equal(check)); } b = tf.cumsum(a, exclusive:true, reverse: true); @@ -127,7 +128,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var o = sess.run(b); - Assert.IsTrue(o[0].array_equal(check)); + Assert.IsTrue(o.array_equal(check)); } } @@ -143,7 +144,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var o = sess.run(d); - Assert.IsTrue(o[0].array_equal(check)); + Assert.IsTrue(o.array_equal(check)); } d = tf.cast(tf.logical_not(b), tf.int32); @@ -152,7 +153,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var o = sess.run(d); - Assert.IsTrue(o[0].array_equal(check)); + Assert.IsTrue(o.array_equal(check)); } d = tf.cast(tf.logical_or(b, c), tf.int32); @@ -161,7 +162,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var o = sess.run(d); - Assert.IsTrue(o[0].array_equal(check)); + Assert.IsTrue(o.array_equal(check)); } d = tf.cast(tf.logical_xor(b, c), tf.int32); @@ -170,7 +171,7 @@ namespace TensorFlowNET.UnitTest using (var sess = 
tf.Session()) { var o = sess.run(d); - Assert.IsTrue(o[0].array_equal(check)); + Assert.IsTrue(o.array_equal(check)); } } @@ -197,7 +198,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator +(Tensor x, Tensor y)` @@ -207,7 +208,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator +(Tensor x, int y)` @@ -216,7 +217,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator +(int x, Tensor y)` @@ -225,7 +226,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } #endregion @@ -246,7 +247,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator +(Tensor x, Tensor y) @@ -256,7 +257,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator +(Tensor x, float y) @@ -265,7 +266,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator +(float x, Tensor y) @@ -274,7 +275,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } #endregion @@ -295,7 +296,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator +(Tensor x, Tensor y) @@ -305,7 +306,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator +(Tensor x, double y) @@ -314,7 +315,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator +(double x, Tensor y) @@ -323,7 +324,7 @@ namespace TensorFlowNET.UnitTest { var o = 
sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } #endregion } @@ -352,7 +353,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator -(Tensor x, Tensor y) @@ -362,7 +363,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator -(Tensor x, int y) @@ -371,7 +372,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator -(int x, Tensor y) @@ -380,7 +381,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], Math.Abs(intResult)); + Assert.AreEqual((int)o, Math.Abs(intResult)); } // Testing `operator -(Tensor x) @@ -389,7 +390,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResultTwo); + Assert.AreEqual((int)o, intResultTwo); } #endregion @@ -411,7 +412,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator -(Tensor x, Tensor y) @@ -421,7 +422,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator -(Tensor x, float y) @@ -430,7 +431,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator -(float x, Tensor y) @@ -439,7 +440,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], Math.Abs(floatResult)); + Assert.AreEqual((float)o, Math.Abs(floatResult)); } // Testing `operator -(Tensor x) @@ -448,7 +449,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResultTwo); + Assert.AreEqual((float)o, floatResultTwo); } #endregion @@ -470,7 +471,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator -(Tensor x, Tensor y) @@ -480,7 +481,7 @@ namespace TensorFlowNET.UnitTest var o = 
sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator -(Tensor x, double y) @@ -489,7 +490,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator -(double x, Tensor y) @@ -498,7 +499,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], Math.Abs(doubleResult)); + Assert.AreEqual((double)o, Math.Abs(doubleResult)); } // Testing `operator -(Tensor x) @@ -507,7 +508,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResultTwo); + Assert.AreEqual((double)o, doubleResultTwo); } #endregion } @@ -593,7 +594,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator *(Tensor x, Tensor y) @@ -603,7 +604,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator *(Tensor x, int y) @@ -612,7 +613,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator *(int x, Tensor y) @@ -621,7 +622,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } #endregion @@ -642,7 +643,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator *(Tensor x, Tensor y) @@ -652,7 +653,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator *(Tensor x, float y) @@ -661,7 +662,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator *(float x, Tensor y) @@ -670,7 +671,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } #endregion @@ -691,7 +692,7 @@ namespace TensorFlowNET.UnitTest var o 
= sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator *(Tensor x, Tensor y) @@ -701,7 +702,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator *(Tensor x, double y) @@ -710,7 +711,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator *(double x, Tensor y) @@ -719,7 +720,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } #endregion } @@ -747,7 +748,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator /(Tensor x, Tensor y) @@ -757,7 +758,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator /(Tensor x, int y) @@ -766,7 +767,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator /(int x, Tensor y) @@ -775,7 +776,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } #endregion @@ -796,7 +797,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator /(Tensor x, Tensor y) @@ -806,7 +807,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator /(Tensor x, float y) @@ -815,7 +816,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator /(float x, Tensor y) @@ -824,7 +825,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } #endregion @@ 
-845,7 +846,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator /(Tensor x, Tensor y) @@ -855,7 +856,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator /(Tensor x, double y) @@ -864,7 +865,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator /(double x, Tensor y) @@ -873,7 +874,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } #endregion } @@ -901,7 +902,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator >(Tensor x, Tensor y) @@ -911,7 +912,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator >(Tensor x, int y) @@ -920,7 +921,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator >(int x, Tensor y) @@ -929,7 +930,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResultTwo); + Assert.AreEqual((int)o, intResultTwo); } #endregion @@ -950,7 +951,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator >(Tensor x, Tensor y) @@ -960,7 +961,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator >(Tensor x, float y) @@ -969,7 +970,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator >(float x, Tensor y) @@ -978,7 +979,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResultTwo); + 
Assert.AreEqual((int)o, floatResultTwo); } #endregion @@ -999,7 +1000,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator >(Tensor x, Tensor y) @@ -1009,7 +1010,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator >(Tensor x, double y) @@ -1018,7 +1019,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator >(double x, Tensor y) @@ -1027,7 +1028,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResultTwo); + Assert.AreEqual((int)o, doubleResultTwo); } #endregion } @@ -1055,7 +1056,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator <(Tensor x, Tensor y) @@ -1065,7 +1066,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator <(Tensor x, int y) @@ -1074,7 +1075,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator <(int x, Tensor y) @@ -1083,7 +1084,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResultTwo); + Assert.AreEqual((int)o, intResultTwo); } #endregion @@ -1104,7 +1105,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator <(Tensor x, Tensor y) @@ -1114,7 +1115,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator <(Tensor x, float y) @@ -1123,7 +1124,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator <(float x, Tensor y) @@ -1132,7 +1133,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, 
cols)))); - Assert.AreEqual((int)o[0], floatResultTwo); + Assert.AreEqual((int)o, floatResultTwo); } #endregion @@ -1153,7 +1154,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator <(Tensor x, Tensor y) @@ -1163,7 +1164,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator <(Tensor x, double y) @@ -1172,7 +1173,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator <(double x, Tensor y) @@ -1181,7 +1182,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResultTwo); + Assert.AreEqual((int)o, doubleResultTwo); } #endregion } @@ -1209,7 +1210,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator >=(Tensor x, Tensor y) @@ -1219,7 +1220,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator >=(Tensor x, int y) @@ -1228,7 +1229,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator >=(int x, Tensor y) @@ -1237,7 +1238,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResultTwo); + Assert.AreEqual((int)o, intResultTwo); } #endregion @@ -1258,7 +1259,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator >=(Tensor x, Tensor y) @@ -1268,7 +1269,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator >=(Tensor x, float y) @@ -1277,7 +1278,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator >=(float x, Tensor y) @@ -1286,7 +1287,7 @@ namespace TensorFlowNET.UnitTest { var o = 
sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResultTwo); + Assert.AreEqual((int)o, floatResultTwo); } #endregion @@ -1307,7 +1308,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator >=(Tensor x, Tensor y) @@ -1317,7 +1318,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator >=(Tensor x, double y) @@ -1326,7 +1327,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator >=(double x, Tensor y) @@ -1335,7 +1336,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResultTwo); + Assert.AreEqual((int)o, doubleResultTwo); } #endregion } @@ -1363,7 +1364,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator <=(Tensor x, Tensor y) @@ -1373,7 +1374,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator <=(Tensor x, int y) @@ -1382,7 +1383,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator <=(int x, Tensor y) @@ -1391,7 +1392,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResultTwo); + Assert.AreEqual((int)o, intResultTwo); } #endregion @@ -1412,7 +1413,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator <=(Tensor x, Tensor y) @@ -1422,7 +1423,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator <=(Tensor x, float y) @@ -1431,7 +1432,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator <=(float x, Tensor y) 
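The recurring change in the operator hunks above is how a scalar result from `sess.run` is read back: the tests now cast the returned `NDArray` directly instead of indexing its first element. The snippet below is a minimal sketch of that pattern, not code from the test file; the operands and the division are illustrative, and only the cast mirrors the patch.

```cs
using System;
using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;

// Illustrative setup: a scalar tensor produced by the `/` operator under test.
var a = tf.constant(9.0);
var b = tf.constant(3.0);
var c = a / b;

using (var sess = tf.Session())
{
    // sess.run on a scalar tensor yields a scalar NDArray.
    NDArray o = sess.run(c);

    // New style in this patch: cast the NDArray itself...
    var value = (double)o;
    // ...instead of the old `(double)o[0]` element access.
    Console.WriteLine(value);   // 3
}
```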
@@ -1440,7 +1441,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResultTwo); + Assert.AreEqual((int)o, floatResultTwo); } #endregion @@ -1461,7 +1462,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator <=(Tensor x, Tensor y) @@ -1471,7 +1472,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator <=(Tensor x, double y) @@ -1480,7 +1481,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator <=(double x, Tensor y) @@ -1489,7 +1490,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResultTwo); + Assert.AreEqual((int)o, doubleResultTwo); } #endregion } diff --git a/test/TensorFlowNET.UnitTest/PlaceholderTest.cs b/test/TensorFlowNET.UnitTest/PlaceholderTest.cs index 14b16c23..5135bd25 100644 --- a/test/TensorFlowNET.UnitTest/PlaceholderTest.cs +++ b/test/TensorFlowNET.UnitTest/PlaceholderTest.cs @@ -17,7 +17,7 @@ namespace TensorFlowNET.UnitTest { var result = sess.run(y, new FeedItem(x, 2)); - Assert.AreEqual((int)result[0], 6); + Assert.AreEqual((int)result, 6); } } } diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs index 701b4b4b..d2ae36d7 100644 --- a/test/TensorFlowNET.UnitTest/PythonTest.cs +++ b/test/TensorFlowNET.UnitTest/PythonTest.cs @@ -165,7 +165,7 @@ namespace TensorFlowNET.UnitTest { using (var sess = tf.Session()) { - var ndarray=tensor.eval(); + var ndarray=tensor.eval(sess); if (typeof(T) == typeof(double)) { double x = ndarray; diff --git a/test/TensorFlowNET.UnitTest/SessionTest.cs b/test/TensorFlowNET.UnitTest/SessionTest.cs index 9c8485ec..62d7c63d 100644 --- a/test/TensorFlowNET.UnitTest/SessionTest.cs +++ b/test/TensorFlowNET.UnitTest/SessionTest.cs @@ -45,7 +45,7 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); EXPECT_EQ(0, outTensor.NDims); ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); - var output_contents = outTensor.Data(); + var output_contents = outTensor.ToArray(); EXPECT_EQ(3 + 2, output_contents[0]); // Add another operation to the graph. 
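The `PlaceholderTest`, `PythonTest` and `SessionTest` hunks in this region make two related adjustments: `eval` now receives the session explicitly (`tensor.eval(sess)` rather than relying on an ambient default session), and raw tensor contents are read with `ToArray()` instead of `Data()`. Below is a minimal sketch of the explicit-session style, with illustrative operands not taken from the tests:

```cs
using Tensorflow;
using static Tensorflow.Binding;

// Illustrative graph; only the eval(sess) call reflects the patch.
var a = tf.constant(2.0f);
var b = tf.constant(3.0f);
var c = tf.multiply(a, b);

using (var sess = tf.Session())
{
    // Pass the session explicitly; the old form was simply c.eval().
    var result = c.eval(sess);
    // result is an NDArray holding the scalar product 6.
}
```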
@@ -66,14 +66,12 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); EXPECT_EQ(0, outTensor.NDims); // scalar ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); - output_contents = outTensor.Data(); + output_contents = outTensor.ToArray(); EXPECT_EQ(-(7 + 2), output_contents[0]); // Clean up csession.CloseAndDelete(s); ASSERT_EQ(TF_Code.TF_OK, s.Code); - graph.Dispose(); - s.Dispose(); } [TestMethod] @@ -84,7 +82,7 @@ namespace TensorFlowNET.UnitTest var c = math_ops.matmul(a, b, name: "matmul"); using (var sess = tf.Session()) { - var result = c.eval(); + var result = c.eval(sess); Assert.AreEqual(6, result.Data()[0]); } } diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj index 848512f0..661d85ea 100644 --- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj @@ -4,6 +4,12 @@ netcoreapp2.2 false + + true + + false + + Open.snk diff --git a/test/TensorFlowNET.UnitTest/TensorTest.cs b/test/TensorFlowNET.UnitTest/TensorTest.cs index 07da9dca..11557f14 100644 --- a/test/TensorFlowNET.UnitTest/TensorTest.cs +++ b/test/TensorFlowNET.UnitTest/TensorTest.cs @@ -112,7 +112,7 @@ namespace TensorFlowNET.UnitTest var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); var tensor = new Tensor(nd); - var array = tensor.Data(); + var array = tensor.ToArray(); EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); EXPECT_EQ(tensor.rank, nd.ndim); diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index 4c5ddd7a..4d9d1059 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -1,4 +1,5 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; +using NumSharp; using Tensorflow; using static Tensorflow.Binding; @@ -16,14 +17,14 @@ namespace TensorFlowNET.UnitTest { session.run(x.initializer); var result = session.run(x); - Assert.AreEqual(10, (int)result[0]); + Assert.AreEqual(10, (int)result); } } [TestMethod] public void StringVar() { - var mammal1 = tf.Variable("Elephant", name: "var1", dtype: tf.chars); + var mammal1 = tf.Variable("Elephant", name: "var1", dtype: tf.@string); var mammal2 = tf.Variable("Tiger"); } @@ -81,7 +82,7 @@ namespace TensorFlowNET.UnitTest using (var session = tf.Session()) { session.run(model); - int result = session.run(y)[0]; + int result = session.run(y); Assert.AreEqual(result, 4); } } @@ -97,12 +98,12 @@ namespace TensorFlowNET.UnitTest var sess = tf.Session(graph); sess.run(init); - var result = sess.run(variable); - Assert.IsTrue((int)result[0] == 31); + NDArray result = sess.run(variable); + Assert.IsTrue((int)result == 31); var assign = variable.assign(12); result = sess.run(assign); - Assert.IsTrue((int)result[0] == 12); + Assert.IsTrue((int)result == 12); } [TestMethod] @@ -118,7 +119,7 @@ namespace TensorFlowNET.UnitTest { sess.run(init_op); // o some work with the model. 
- inc_v1.op.run(); + inc_v1.op.run(session: sess); } } @@ -139,7 +140,7 @@ namespace TensorFlowNET.UnitTest for(int i = 0; i < 5; i++) { x = x + 1; - result = session.run(x)[0]; + result = session.run(x); print(result); } } diff --git a/test/TensorFlowNET.UnitTest/c_test_util.cs b/test/TensorFlowNET.UnitTest/c_test_util.cs index 1b6909e7..627d7c2f 100644 --- a/test/TensorFlowNET.UnitTest/c_test_util.cs +++ b/test/TensorFlowNET.UnitTest/c_test_util.cs @@ -1,4 +1,6 @@ -using Tensorflow; +using System.Diagnostics.CodeAnalysis; +using Tensorflow; +using Tensorflow.Util; using Buffer = Tensorflow.Buffer; namespace TensorFlowNET.UnitTest @@ -26,12 +28,15 @@ namespace TensorFlowNET.UnitTest return op; } + [SuppressMessage("ReSharper", "RedundantAssignment")] public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) { - var buffer = new Buffer(); - c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); - attr_value = AttrValue.Parser.ParseFrom(buffer); - buffer.Dispose(); + using (var buffer = new Buffer()) + { + c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); + attr_value = AttrValue.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + } + return s.Code == TF_Code.TF_OK; } @@ -42,7 +47,7 @@ namespace TensorFlowNET.UnitTest { c_api.TF_GraphToGraphDef(graph, buffer, s); s.Check(); - return GraphDef.Parser.ParseFrom(buffer); + return GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); } } diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs index 4e82a2b6..72dd83ea 100644 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs +++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs @@ -11,6 +11,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test /// /// https://www.tensorflow.org/api_docs/python/tf/while_loop /// + [Ignore] [TestMethod] public void SimpleWhileLoop() { diff --git a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs index 1fd7d3aa..3a5515d9 100644 --- a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs +++ b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs @@ -31,7 +31,7 @@ namespace TensorFlowNET.UnitTest.nn_test var y_np = this._ZeroFraction(x_np); var x_tf = constant_op.constant(x_np); - x_tf.SetShape(x_shape); + x_tf.set_shape(x_shape); var y_tf = nn_impl.zero_fraction(x_tf); var y_tf_np = self.evaluate(y_tf); diff --git a/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs b/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs index cdbd5f14..310ac634 100644 --- a/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs +++ b/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs @@ -24,16 +24,17 @@ namespace TensorFlowNET.UnitTest.ops_test [TestMethod] public void TestShape() { - var g = tf.Graph().as_default(); - - var x = constant_op.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } }); - var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] { x }, new Operation[0]); - var op = g._create_op_from_tf_operation(c_op); - - Assert.AreEqual("myop", op.name); - Assert.AreEqual("Identity", op.type); - Assert.AreEqual(1, len(op.outputs)); - assertItemsEqual(new[] { 2, 3 }, op.outputs[0].shape); + using (var g = tf.Graph().as_default()) + { + var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}}); + 
var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); + var op = g._create_op_from_tf_operation(c_op); + + Assert.AreEqual("myop", op.name); + Assert.AreEqual("Identity", op.type); + Assert.AreEqual(1, len(op.outputs)); + assertItemsEqual(new[] {2, 3}, op.outputs[0].shape); + } } [TestMethod]
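Beyond the assertion rewrites, the later hunks share a resource-handling theme: native-backed objects such as `Graph`, `Buffer` and `Session` are scoped with `using` blocks rather than disposed by hand (see the removed `graph.Dispose()`/`s.Dispose()` calls in `SessionTest`, the `using (var buffer = new Buffer())` block in `c_test_util`, and the `using`-wrapped graph in `CreateOpFromTfOperationTest`). A minimal sketch of the pattern, assuming only APIs that appear elsewhere in this patch; the constant value is illustrative:

```cs
using Tensorflow;
using static Tensorflow.Binding;

// Scope the graph for the duration of the test...
using (var g = tf.Graph().as_default())
{
    var x = constant_op.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } });

    // ...and the session as well, so both native handles are released
    // deterministically when the blocks exit.
    using (var sess = tf.Session(g))
    {
        var result = sess.run(x);
        // no explicit g.Dispose() / sess.Dispose() calls are needed here
    }
}
```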