| @@ -1,6 +1,6 @@ | |||
|  | |||
| **TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. | |||
| **TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in C# which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. | |||
| [](https://gitter.im/sci-sharp/community) | |||
| [](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net) | |||
| @@ -16,7 +16,7 @@ TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). <a | |||
| ### Why TensorFlow.NET ? | |||
| `SciSharp STACK`'s mission is to bring popular data science technology into the .NET world and to provide .NET developers with a powerful Machine Learning tool set without reinventing the wheel. Scince the APIs are kept as similar as possible you can immediately adapt any existing Tensorflow code in C# with a zero learning curve. Take a look at a comparison picture and see how comfortably a Tensorflow/Python script translates into a C# program with TensorFlow.NET. | |||
| `SciSharp STACK`'s mission is to bring popular data science technology into the .NET world and to provide .NET developers with a powerful Machine Learning tool set without reinventing the wheel. Since the APIs are kept as similar as possible you can immediately adapt any existing Tensorflow code in C# with a zero learning curve. Take a look at a comparison picture and see how comfortably a Tensorflow/Python script translates into a C# program with TensorFlow.NET. | |||
|  | |||
| @@ -34,40 +34,15 @@ PM> Install-Package TensorFlow.NET | |||
| ### Install tensorflow binary | |||
| ### For CPU version | |||
| PM> Install-Package SciSharp.TensorFlow.Redist | |||
| ### For GPU version (CUDA and cuDNN are required) | |||
| PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU | |||
| ``` | |||
| Import TF.NET. | |||
| ```cs | |||
| using Tensorflow; | |||
| ``` | |||
| Add two constants: | |||
| ```cs | |||
| // Create a Constant op | |||
| var a = tf.constant(4.0f); | |||
| var b = tf.constant(5.0f); | |||
| var c = tf.add(a, b); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var o = sess.run(c); | |||
| } | |||
| ``` | |||
| Import TF.NET in your project. | |||
| Feed placeholder: | |||
| ```cs | |||
| // Create a placeholder op | |||
| var a = tf.placeholder(tf.float32); | |||
| var b = tf.placeholder(tf.float32); | |||
| var c = tf.add(a, b); | |||
| using(var sess = tf.Session()) | |||
| { | |||
| var o = sess.run(c, new FeedItem(a, 3.0f), new FeedItem(b, 2.0f)); | |||
| } | |||
| using static Tensorflow.Binding; | |||
| ``` | |||
| Linear Regression: | |||
| @@ -91,39 +66,40 @@ var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); | |||
| var init = tf.global_variables_initializer(); | |||
| // Start training | |||
| with(tf.Session(), sess => | |||
| using(tf.Session()) | |||
| { | |||
| // Run the initializer | |||
| sess.run(init); | |||
| // Fit all training data | |||
| for (int epoch = 0; epoch < training_epochs; epoch++) | |||
| { | |||
| foreach (var (x, y) in zip<float>(train_X, train_Y)) | |||
| sess.run(optimizer, new FeedItem(X, x), new FeedItem(Y, y)); | |||
| sess.run(optimizer, (X, x), (Y, y)); | |||
| // Display logs per epoch step | |||
| if ((epoch + 1) % display_step == 0) | |||
| { | |||
| var c = sess.run(cost, new FeedItem(X, train_X), new FeedItem(Y, train_Y)); | |||
| var c = sess.run(cost, (X, train_X), (Y, train_Y)); | |||
| Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}"); | |||
| } | |||
| Console.WriteLine("Optimization Finished!"); | |||
| var training_cost = sess.run(cost, new FeedItem(X, train_X), new FeedItem(Y, train_Y)); | |||
| Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); | |||
| // Testing example | |||
| var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f); | |||
| var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); | |||
| Console.WriteLine("Testing... (Mean square loss Comparison)"); | |||
| var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), new FeedItem(X, test_X), new FeedItem(Y, test_Y)); | |||
| Console.WriteLine($"Testing cost={testing_cost}"); | |||
| var diff = Math.Abs((float)training_cost - (float)testing_cost); | |||
| Console.WriteLine($"Absolute mean square loss difference: {diff}"); | |||
| } | |||
| Console.WriteLine("Optimization Finished!"); | |||
| var training_cost = sess.run(cost, (X, train_X), (Y, train_Y)); | |||
| Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); | |||
| // Testing example | |||
| var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f); | |||
| var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); | |||
| Console.WriteLine("Testing... (Mean square loss Comparison)"); | |||
| var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), | |||
| (X, test_X), (Y, test_Y)); | |||
| Console.WriteLine($"Testing cost={testing_cost}"); | |||
| var diff = Math.Abs((float)training_cost - (float)testing_cost); | |||
| Console.WriteLine($"Absolute mean square loss difference: {diff}"); | |||
| return diff < 0.01; | |||
| }); | |||
| ``` | |||
| @@ -17,6 +17,6 @@ | |||
| <PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&v=4</PackageIconUrl> | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| <PackageReference Include="NumSharp" Version="0.20.0-alpha1" /> | |||
| <PackageReference Include="NumSharp" Version="0.20.0" /> | |||
| </ItemGroup> | |||
| </Project> | |||
| @@ -59,6 +59,6 @@ namespace Tensorflow | |||
| } | |||
| [DllImport(TensorFlowLibName)] | |||
| public static unsafe extern IntPtr TF_Version(); | |||
| public static extern IntPtr TF_Version(); | |||
| } | |||
| } | |||
| @@ -14,11 +14,16 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using static Tensorflow.ops; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class tensorflow | |||
| { | |||
| public graph_util_impl graph_util => new graph_util_impl(); | |||
| public GraphKeys GraphKeys { get; } = new GraphKeys(); | |||
| public Graph get_default_graph() | |||
| { | |||
| return ops.get_default_graph(); | |||
| @@ -0,0 +1,61 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System.Collections.Generic; | |||
| using Tensorflow.IO; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class tensorflow | |||
| { | |||
| public image_internal image = new image_internal(); | |||
| public class image_internal | |||
| { | |||
| public Tensor decode_jpeg(Tensor contents, | |||
| int channels = 0, | |||
| int ratio = 1, | |||
| bool fancy_upscaling = true, | |||
| bool try_recover_truncated = false, | |||
| float acceptable_fraction = 1, | |||
| string dct_method = "", | |||
| string name = null) | |||
| => gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio, | |||
| fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated, | |||
| acceptable_fraction: acceptable_fraction, dct_method: dct_method); | |||
| public Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, string name = null) | |||
| => gen_image_ops.resize_bilinear(images, size, align_corners: align_corners, name: name); | |||
| public Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, string name = null) | |||
| => gen_image_ops.convert_image_dtype(image, dtype, saturate: saturate, name: name); | |||
| public Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype = TF_DataType.TF_UINT8, | |||
| string name = null, bool expand_animations = true) | |||
| => image_ops_impl.decode_image(contents, channels: channels, dtype: dtype, | |||
| name: name, expand_animations: expand_animations); | |||
| /// <summary> | |||
| /// Convenience function to check if the 'contents' encodes a JPEG image. | |||
| /// </summary> | |||
| /// <param name="contents"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public static Tensor is_jpeg(Tensor contents, string name = null) | |||
| => image_ops_impl.is_jpeg(contents, name: name); | |||
| } | |||
| } | |||
| } | |||
| @@ -52,5 +52,13 @@ namespace Tensorflow | |||
| stddev: stddev, | |||
| seed: seed, | |||
| dtype: dtype); | |||
| public IInitializer random_normal_initializer(float mean = 0.0f, | |||
| float stddev = 1.0f, | |||
| int? seed = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid) => new RandomNormal(mean: mean, | |||
| stddev: stddev, | |||
| seed: seed, | |||
| dtype: dtype); | |||
| } | |||
| } | |||
| @@ -24,8 +24,6 @@ namespace Tensorflow | |||
| public GFile gfile = new GFile(); | |||
| public Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name); | |||
| public gen_image_ops image => new gen_image_ops(); | |||
| public void import_graph_def(GraphDef graph_def, | |||
| Dictionary<string, Tensor> input_map = null, | |||
| string[] return_elements = null, | |||
| @@ -148,6 +148,24 @@ namespace Tensorflow | |||
| }); | |||
| } | |||
| /// <summary> | |||
| /// Local Response Normalization. | |||
| /// </summary> | |||
| /// <param name="input"></param> | |||
| /// <param name="depth_radius"></param> | |||
| /// <param name="bias"></param> | |||
| /// <param name="alpha"></param> | |||
| /// <param name="beta"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public Tensor lrn(Tensor input, int depth_radius = 5, int bias = 1, | |||
| int alpha = 1, float beta = 0.5f, string name = null) | |||
| => gen_nn_ops.local_response_normalization(input, depth_radius: depth_radius, bias: bias, | |||
| alpha: alpha, beta: beta, name: name); | |||
| public Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) | |||
| => nn_ops.leaky_relu(features, alpha: alpha, name: name); | |||
| public rnn_cell_impl rnn_cell => new rnn_cell_impl(); | |||
| public Tensor softmax(Tensor logits, int axis = -1, string name = null) | |||
| @@ -33,5 +33,13 @@ namespace Tensorflow | |||
| /// <returns>The scope name.</returns> | |||
| public ops.NameScope name_scope(string name, string default_name = "", object values = null) | |||
| => new ops.NameScope(name, default_name, values); | |||
| /// <summary> | |||
| /// Does nothing. Only useful as a placeholder for control edges. | |||
| /// </summary> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public Tensor no_op(string name = null) | |||
| => gen_control_flow_ops.no_op(name: name); | |||
| } | |||
| } | |||
| @@ -0,0 +1,32 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System.Collections.Generic; | |||
| using Tensorflow.IO; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class tensorflow | |||
| { | |||
| public strings_internal strings = new strings_internal(); | |||
| public class strings_internal | |||
| { | |||
| public Tensor substr(Tensor input, int pos, int len, | |||
| string name = null, string @uint = "BYTE") | |||
| => string_ops.substr(input, pos, len, name: name, @uint: @uint); | |||
| } | |||
| } | |||
| } | |||
| @@ -31,6 +31,9 @@ namespace Tensorflow | |||
| public Optimizer AdamOptimizer(float learning_rate, string name = "Adam") | |||
| => new AdamOptimizer(learning_rate, name: name); | |||
| public ExponentialMovingAverage ExponentialMovingAverage(float decay) | |||
| => new ExponentialMovingAverage(decay); | |||
| public Saver Saver(VariableV1[] var_list = null) => new Saver(var_list: var_list); | |||
| public string write_graph(Graph graph, string logdir, string name, bool as_text = true) | |||
| @@ -15,6 +15,7 @@ | |||
| ******************************************************************************/ | |||
| using System.Collections.Generic; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -22,7 +23,7 @@ namespace Tensorflow | |||
| { | |||
| public VariableV1[] global_variables(string scope = null) | |||
| { | |||
| return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) as List<VariableV1>) | |||
| return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) as List<VariableV1>) | |||
| .ToArray(); | |||
| } | |||
| @@ -32,6 +33,14 @@ namespace Tensorflow | |||
| return variables.variables_initializer(g.ToArray()); | |||
| } | |||
| /// <summary> | |||
| /// Returns all variables created with `trainable=True`. | |||
| /// </summary> | |||
| /// <param name="scope"></param> | |||
| /// <returns></returns> | |||
| public VariableV1[] trainable_variables(string scope = null) | |||
| => (variables.trainable_variables() as List<VariableV1>).ToArray(); | |||
| public RefVariable get_variable(string name, | |||
| TensorShape shape = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| @@ -0,0 +1,4 @@ | |||
| using System.Runtime.CompilerServices; | |||
| #if DEBUG | |||
| [assembly: InternalsVisibleTo("TensorFlowNET.UnitTest, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")] | |||
| #endif | |||
| @@ -178,13 +178,18 @@ namespace Tensorflow | |||
| public static IEnumerable<(TKey, TValue)> enumerate<TKey, TValue>(KeyValuePair<TKey, TValue>[] values) | |||
| { | |||
| foreach (var item in values) | |||
| var len = values.Length; | |||
| for (var i = 0; i < len; i++) | |||
| { | |||
| var item = values[i]; | |||
| yield return (item.Key, item.Value); | |||
| } | |||
| } | |||
| public static IEnumerable<(int, T)> enumerate<T>(IList<T> values) | |||
| { | |||
| for (int i = 0; i < values.Count; i++) | |||
| var len = values.Count; | |||
| for (int i = 0; i < len; i++) | |||
| yield return (i, values[i]); | |||
| } | |||
| @@ -308,15 +313,14 @@ namespace Tensorflow | |||
| public static IEnumerable TupleToEnumerable(object tuple) | |||
| { | |||
| Type t = tuple.GetType(); | |||
| if(t.IsGenericType && (t.FullName.StartsWith("System.Tuple") || t.FullName.StartsWith("System.ValueTuple"))) | |||
| if (t.IsGenericType && (t.FullName.StartsWith("System.Tuple") || t.FullName.StartsWith("System.ValueTuple"))) | |||
| { | |||
| var flds = t.GetFields(); | |||
| for(int i = 0; i < flds.Length;i++) | |||
| for (int i = 0; i < flds.Length; i++) | |||
| { | |||
| yield return flds[i].GetValue(tuple); | |||
| } | |||
| } | |||
| else | |||
| } else | |||
| { | |||
| throw new System.Exception("Expected Tuple."); | |||
| } | |||
| @@ -329,12 +333,9 @@ namespace Tensorflow | |||
| public static bool isinstance(object Item1, object tuple) | |||
| { | |||
| var tup = TupleToEnumerable(tuple); | |||
| foreach(var t in tup) | |||
| { | |||
| if(isinstance(Item1, (Type)t)) | |||
| foreach (var t in TupleToEnumerable(tuple)) | |||
| if (isinstance(Item1, (Type) t)) | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| } | |||
| @@ -15,58 +15,116 @@ | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Runtime.InteropServices; | |||
| using NumSharp.Backends.Unmanaged; | |||
| using static Tensorflow.c_api; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// Represents a TF_Buffer that can be passed to Tensorflow. | |||
| /// </summary> | |||
| public class Buffer : DisposableObject | |||
| { | |||
| private TF_Buffer buffer => Marshal.PtrToStructure<TF_Buffer>(_handle); | |||
| private unsafe TF_Buffer buffer | |||
| { | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| get => *bufferptr; | |||
| } | |||
| private unsafe TF_Buffer* bufferptr | |||
| { | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| get => (TF_Buffer*) _handle; | |||
| } | |||
| public byte[] Data | |||
| /// <summary> | |||
| /// The memory block representing this buffer. | |||
| /// </summary> | |||
| /// <remarks>The deallocator is set to null.</remarks> | |||
| public UnmanagedMemoryBlock<byte> MemoryBlock | |||
| { | |||
| get | |||
| get | |||
| { | |||
| var data = new byte[buffer.length]; | |||
| if (data.Length > 0) | |||
| Marshal.Copy(buffer.data, data, 0, data.Length); | |||
| return data; | |||
| unsafe | |||
| { | |||
| EnsureNotDisposed(); | |||
| var buff = (TF_Buffer*) _handle; | |||
| return new UnmanagedMemoryBlock<byte>((byte*) buff->data.ToPointer(), (long) buff->length); | |||
| } | |||
| } | |||
| } | |||
| public int Length => (int)buffer.length; | |||
| public Buffer() | |||
| /// <summary> | |||
| /// The bytes length of this buffer. | |||
| /// </summary> | |||
| public ulong Length | |||
| { | |||
| _handle = c_api.TF_NewBuffer(); | |||
| get | |||
| { | |||
| EnsureNotDisposed(); | |||
| return buffer.length; | |||
| } | |||
| } | |||
| public Buffer(IntPtr handle) | |||
| public Buffer() => _handle = TF_NewBuffer(); | |||
| internal Buffer(IntPtr handle) | |||
| { | |||
| if (handle == IntPtr.Zero) | |||
| throw new ArgumentException("Handle (IntPtr) can't be zero.", nameof(handle)); | |||
| _handle = handle; | |||
| } | |||
| public Buffer(byte[] data) | |||
| { | |||
| var dst = Marshal.AllocHGlobal(data.Length); | |||
| Marshal.Copy(data, 0, dst, data.Length); | |||
| public Buffer(byte[] data) : this(_toBuffer(data)) | |||
| { } | |||
| _handle = c_api.TF_NewBufferFromString(dst, (ulong)data.Length); | |||
| private static IntPtr _toBuffer(byte[] data) | |||
| { | |||
| if (data == null) | |||
| throw new ArgumentNullException(nameof(data)); | |||
| Marshal.FreeHGlobal(dst); | |||
| unsafe | |||
| { | |||
| fixed (byte* src = data) | |||
| return TF_NewBufferFromString(new IntPtr(src), (ulong) data.LongLength); | |||
| } | |||
| } | |||
| public static implicit operator IntPtr(Buffer buffer) | |||
| { | |||
| buffer.EnsureNotDisposed(); | |||
| return buffer._handle; | |||
| } | |||
| public static implicit operator byte[](Buffer buffer) | |||
| public static explicit operator byte[](Buffer buffer) => buffer.ToArray(); //has to be explicit, developer will assume it doesn't cost. | |||
| /// <summary> | |||
| /// Copies this buffer's contents onto a <see cref="byte"/> array. | |||
| /// </summary> | |||
| public byte[] ToArray() | |||
| { | |||
| return buffer.Data; | |||
| EnsureNotDisposed(); | |||
| unsafe | |||
| { | |||
| var len = buffer.length; | |||
| if (len == 0) | |||
| return Array.Empty<byte>(); | |||
| byte[] data = new byte[len]; | |||
| fixed (byte* dst = data) | |||
| System.Buffer.MemoryCopy((void*) bufferptr->data, dst, len, len); | |||
| return data; | |||
| } | |||
| } | |||
| protected override void DisposeUnManagedState(IntPtr handle) | |||
| => c_api.TF_DeleteBuffer(handle); | |||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||
| { | |||
| TF_DeleteBuffer(handle); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -16,6 +16,8 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics.CodeAnalysis; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| @@ -26,52 +28,71 @@ namespace Tensorflow | |||
| public abstract class DisposableObject : IDisposable | |||
| { | |||
| protected IntPtr _handle; | |||
| protected bool _disposed; | |||
| protected DisposableObject() { } | |||
| [SuppressMessage("ReSharper", "UnusedMember.Global")] | |||
| protected DisposableObject() | |||
| { } | |||
| public DisposableObject(IntPtr handle) | |||
| { | |||
| _handle = handle; | |||
| } | |||
| protected DisposableObject(IntPtr handle) | |||
| => _handle = handle; | |||
| protected virtual void DisposeManagedState() | |||
| [SuppressMessage("ReSharper", "InvertIf")] | |||
| private void internal_dispose(bool disposing) | |||
| { | |||
| } | |||
| if (_disposed) | |||
| return; | |||
| protected abstract void DisposeUnManagedState(IntPtr handle); | |||
| _disposed = true; | |||
| protected virtual void Dispose(bool disposing) | |||
| { | |||
| //first handle managed, they might use the unmanaged resources. | |||
| if (disposing) | |||
| { | |||
| // free unmanaged resources (unmanaged objects) and override a finalizer below. | |||
| if (_handle != IntPtr.Zero) | |||
| { | |||
| // dispose managed state (managed objects). | |||
| DisposeManagedState(); | |||
| // set large fields to null. | |||
| DisposeUnManagedState(_handle); | |||
| // dispose managed state (managed objects). | |||
| DisposeManagedResources(); | |||
| _handle = IntPtr.Zero; | |||
| } | |||
| //free unmanaged memory | |||
| if (_handle != IntPtr.Zero) | |||
| { | |||
| DisposeUnmanagedResources(_handle); | |||
| _handle = IntPtr.Zero; | |||
| } | |||
| } | |||
| // override a finalizer only if Dispose(bool disposing) above has code to free unmanaged resources. | |||
| /// <summary> | |||
| /// Dispose any managed resources. | |||
| /// </summary> | |||
| /// <remarks>Equivalent to what you would perform inside <see cref="Dispose()"/></remarks> | |||
| protected virtual void DisposeManagedResources() | |||
| { } | |||
| /// <summary> | |||
| /// Dispose any unmanaged resources related to given <paramref name="handle"/>. | |||
| /// </summary> | |||
| protected abstract void DisposeUnmanagedResources(IntPtr handle); | |||
| ~DisposableObject() | |||
| { | |||
| // Do not change this code. Put cleanup code in Dispose(bool disposing) above. | |||
| Dispose(false); | |||
| internal_dispose(false); | |||
| } | |||
| // This code added to correctly implement the disposable pattern. | |||
| public void Dispose() | |||
| { | |||
| // Do not change this code. Put cleanup code in Dispose(bool disposing) above. | |||
| Dispose(true); | |||
| // uncomment the following line if the finalizer is overridden above. | |||
| GC.SuppressFinalize(this); | |||
| lock(this) | |||
| { | |||
| internal_dispose(true); | |||
| GC.SuppressFinalize(this); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// If <see cref="_handle"/> is <see cref="IntPtr.Zero"/> then throws <see cref="ObjectDisposedException"/> | |||
| /// </summary> | |||
| /// <exception cref="ObjectDisposedException">When <see cref="_handle"/> is <see cref="IntPtr.Zero"/></exception> | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| protected void EnsureNotDisposed() | |||
| { | |||
| if (_disposed) | |||
| throw new ObjectDisposedException($"Unable to access disposed object, Type: {GetType().Name}"); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -2,12 +2,10 @@ | |||
| namespace Tensorflow.Eager | |||
| { | |||
| public class Context : IDisposable | |||
| public class Context : DisposableObject | |||
| { | |||
| private IntPtr _handle; | |||
| public static int GRAPH_MODE = 0; | |||
| public static int EAGER_MODE = 1; | |||
| public const int GRAPH_MODE = 0; | |||
| public const int EAGER_MODE = 1; | |||
| public int default_execution_mode; | |||
| @@ -17,19 +15,16 @@ namespace Tensorflow.Eager | |||
| status.Check(true); | |||
| } | |||
| public void Dispose() | |||
| { | |||
| c_api.TFE_DeleteContext(_handle); | |||
| } | |||
| /// <summary> | |||
| /// Dispose any unmanaged resources related to given <paramref name="handle"/>. | |||
| /// </summary> | |||
| protected sealed override void DisposeUnmanagedResources(IntPtr handle) | |||
| => c_api.TFE_DeleteContext(_handle); | |||
| public bool executing_eagerly() | |||
| { | |||
| return false; | |||
| } | |||
| public static implicit operator IntPtr(Context ctx) | |||
| { | |||
| return ctx._handle; | |||
| } | |||
| public bool executing_eagerly() => false; | |||
| public static implicit operator IntPtr(Context ctx) | |||
| => ctx._handle; | |||
| } | |||
| } | |||
| @@ -1,24 +1,22 @@ | |||
| using System; | |||
| using System.IO; | |||
| namespace Tensorflow.Eager | |||
| { | |||
| public class ContextOptions : IDisposable | |||
| public class ContextOptions : DisposableObject | |||
| { | |||
| private IntPtr _handle; | |||
| public ContextOptions() : base(c_api.TFE_NewContextOptions()) | |||
| { } | |||
| public ContextOptions() | |||
| { | |||
| _handle = c_api.TFE_NewContextOptions(); | |||
| } | |||
| /// <summary> | |||
| /// Dispose any unmanaged resources related to given <paramref name="handle"/>. | |||
| /// </summary> | |||
| protected sealed override void DisposeUnmanagedResources(IntPtr handle) | |||
| => c_api.TFE_DeleteContextOptions(_handle); | |||
| public void Dispose() | |||
| { | |||
| c_api.TFE_DeleteContextOptions(_handle); | |||
| } | |||
| public static implicit operator IntPtr(ContextOptions opts) | |||
| { | |||
| return opts._handle; | |||
| } | |||
| public static implicit operator IntPtr(ContextOptions opts) | |||
| => opts._handle; | |||
| } | |||
| } | |||
| @@ -2,7 +2,7 @@ | |||
| namespace Tensorflow | |||
| { | |||
| public class KeyError : Exception | |||
| public class KeyError : TensorflowException | |||
| { | |||
| public KeyError() : base() | |||
| { | |||
| @@ -2,7 +2,7 @@ | |||
| namespace Tensorflow | |||
| { | |||
| public class RuntimeError : Exception | |||
| public class RuntimeError : TensorflowException | |||
| { | |||
| public RuntimeError() : base() | |||
| { | |||
| @@ -0,0 +1,36 @@ | |||
| using System; | |||
| using System.Runtime.Serialization; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// Serves as a base class to all exceptions of Tensorflow.NET. | |||
| /// </summary> | |||
| [Serializable] | |||
| public class TensorflowException : Exception | |||
| { | |||
| /// <summary>Initializes a new instance of the <see cref="T:System.Exception"></see> class.</summary> | |||
| public TensorflowException() | |||
| { } | |||
| /// <summary>Initializes a new instance of the <see cref="T:System.Exception"></see> class with serialized data.</summary> | |||
| /// <param name="info">The <see cref="T:System.Runtime.Serialization.SerializationInfo"></see> that holds the serialized object data about the exception being thrown.</param> | |||
| /// <param name="context">The <see cref="T:System.Runtime.Serialization.StreamingContext"></see> that contains contextual information about the source or destination.</param> | |||
| /// <exception cref="T:System.ArgumentNullException">The <paramref name="info">info</paramref> parameter is null.</exception> | |||
| /// <exception cref="T:System.Runtime.Serialization.SerializationException">The class name is null or <see cref="P:System.Exception.HResult"></see> is zero (0).</exception> | |||
| protected TensorflowException(SerializationInfo info, StreamingContext context) : base(info, context) | |||
| { } | |||
| /// <summary>Initializes a new instance of the <see cref="T:System.Exception"></see> class with a specified error message.</summary> | |||
| /// <param name="message">The message that describes the error.</param> | |||
| public TensorflowException(string message) : base(message) | |||
| { } | |||
| /// <summary>Initializes a new instance of the <see cref="T:System.Exception"></see> class with a specified error message and a reference to the inner exception that is the cause of this exception.</summary> | |||
| /// <param name="message">The error message that explains the reason for the exception.</param> | |||
| /// <param name="innerException">The exception that is the cause of the current exception, or a null reference (Nothing in Visual Basic) if no inner exception is specified.</param> | |||
| public TensorflowException(string message, Exception innerException) : base(message, innerException) | |||
| { } | |||
| } | |||
| } | |||
| @@ -2,7 +2,7 @@ | |||
| namespace Tensorflow | |||
| { | |||
| public class TypeError : Exception | |||
| public class TypeError : TensorflowException | |||
| { | |||
| public TypeError() : base() | |||
| { | |||
| @@ -2,7 +2,7 @@ | |||
| namespace Tensorflow | |||
| { | |||
| public class ValueError : Exception | |||
| public class ValueError : TensorflowException | |||
| { | |||
| public ValueError() : base() | |||
| { | |||
| @@ -6,10 +6,5 @@ | |||
| { | |||
| } | |||
| ~ScopedTFImportGraphDefOptions() | |||
| { | |||
| base.Dispose(); | |||
| } | |||
| } | |||
| } | |||
| @@ -13,10 +13,5 @@ namespace Tensorflow.Framework.Models | |||
| { | |||
| } | |||
| ~ScopedTFImportGraphDefResults() | |||
| { | |||
| base.Dispose(); | |||
| } | |||
| } | |||
| } | |||
| @@ -5,10 +5,5 @@ | |||
| public ScopedTFStatus() : base() | |||
| { | |||
| } | |||
| ~ScopedTFStatus() | |||
| { | |||
| base.Dispose(); | |||
| } | |||
| } | |||
| } | |||
| @@ -95,7 +95,7 @@ namespace Tensorflow | |||
| break; | |||
| case KindOneofCase.BytesList: | |||
| //var proto_type = ops.get_collection_proto_type(key) | |||
| if (ops.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key)) | |||
| if (tf.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key)) | |||
| { | |||
| foreach (var value in col.Value.BytesList.Value) | |||
| { | |||
| @@ -146,7 +146,7 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| var variables = graph.get_collection<VariableV1>(ops.GraphKeys.GLOBAL_VARIABLES, | |||
| var variables = graph.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, | |||
| scope: scope_to_prepend_to_names); | |||
| var var_list = new Dictionary<string, VariableV1>(); | |||
| variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); | |||
| @@ -180,7 +180,7 @@ namespace Tensorflow | |||
| var graph = ops.get_default_graph(); | |||
| var var_list = new Dictionary<string, RefVariable>(); | |||
| var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) as List<RefVariable>; | |||
| var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) as List<RefVariable>; | |||
| if (variables != null) | |||
| { | |||
| @@ -15,6 +15,8 @@ | |||
| ******************************************************************************/ | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -27,12 +29,12 @@ namespace Tensorflow | |||
| if(_registered_ops == null) | |||
| { | |||
| _registered_ops = new Dictionary<string, OpDef>(); | |||
| var handle = c_api.TF_GetAllOpList(); | |||
| var buffer = new Buffer(handle); | |||
| var op_list = OpList.Parser.ParseFrom(buffer); | |||
| foreach (var op_def in op_list.Op) | |||
| _registered_ops[op_def.Name] = op_def; | |||
| using (var buffer = new Buffer(c_api.TF_GetAllOpList())) | |||
| { | |||
| var op_list = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
| foreach (var op_def in op_list.Op) | |||
| _registered_ops[op_def.Name] = op_def; | |||
| } | |||
| } | |||
| return _registered_ops; | |||
| @@ -14,49 +14,62 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| public class DefaultGraphStack | |||
| /// <summary> | |||
| /// Serves as a stack for determining current default graph. | |||
| /// </summary> | |||
| public class DefaultGraphStack | |||
| { | |||
| List<StackModel> stack = new List<StackModel>(); | |||
| private readonly List<StackModel> _stack = new List<StackModel>(); | |||
| public void set_controller(Graph @default) | |||
| { | |||
| if (!stack.Exists(x => x.Graph == @default)) | |||
| stack.Add(new StackModel { Graph = @default, IsDefault = true }); | |||
| if (!_stack.Exists(x => x.Graph == @default)) | |||
| _stack.Add(new StackModel {Graph = @default, IsDefault = true}); | |||
| foreach (var s in stack) | |||
| foreach (var s in _stack) | |||
| s.IsDefault = s.Graph == @default; | |||
| } | |||
| public Graph get_controller() | |||
| { | |||
| if (stack.Count(x => x.IsDefault) == 0) | |||
| stack.Add(new StackModel { Graph = tf.Graph(), IsDefault = true }); | |||
| if (_stack.Count(x => x.IsDefault) == 0) | |||
| _stack.Add(new StackModel {Graph = tf.Graph(), IsDefault = true}); | |||
| for (var i = _stack.Count - 1; i >= 0; i--) | |||
| { | |||
| var x = _stack[i]; | |||
| if (x.IsDefault) | |||
| return x.Graph; | |||
| } | |||
| return stack.Last(x => x.IsDefault).Graph; | |||
| throw new TensorflowException("Unable to find a default graph"); | |||
| } | |||
| public bool remove(Graph g) | |||
| { | |||
| var sm = stack.FirstOrDefault(x => x.Graph == g); | |||
| if (sm == null) return false; | |||
| return stack.Remove(sm); | |||
| if (_stack.Count == 0) | |||
| return false; | |||
| var sm = _stack.Find(model => model.Graph == g); | |||
| return sm != null && _stack.Remove(sm); | |||
| } | |||
| public void reset() | |||
| { | |||
| stack.Clear(); | |||
| _stack.Clear(); | |||
| } | |||
| } | |||
| public class StackModel | |||
| { | |||
| public Graph Graph { get; set; } | |||
| public bool IsDefault { get; set; } | |||
| private class StackModel | |||
| { | |||
| public Graph Graph { get; set; } | |||
| public bool IsDefault { get; set; } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -15,6 +15,7 @@ | |||
| ******************************************************************************/ | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics.CodeAnalysis; | |||
| using System.Linq; | |||
| using Tensorflow.Operations; | |||
| @@ -66,8 +67,9 @@ namespace Tensorflow | |||
| /// within the context should have control dependencies on | |||
| /// `control_inputs`. | |||
| /// </summary> | |||
| [SuppressMessage("ReSharper", "CoVariantArrayConversion")] | |||
| public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) | |||
| => control_dependencies(control_inputs == null ? null : control_inputs.OfType<object>().ToArray()); | |||
| => control_dependencies((object[])control_inputs); | |||
| /// <summary> | |||
| /// Returns a context manager that specifies control dependencies. | |||
| @@ -14,6 +14,9 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System.IO; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class Graph | |||
| @@ -23,21 +26,19 @@ namespace Tensorflow | |||
| var buffer = new Buffer(); | |||
| c_api.TF_GraphToGraphDef(_handle, buffer, s); | |||
| s.Check(true); | |||
| // var def = GraphDef.Parser.ParseFrom(buffer); | |||
| // buffer.Dispose(); | |||
| return buffer; | |||
| } | |||
| private GraphDef _as_graph_def(bool add_shapes = false) | |||
| { | |||
| var status = new Status(); | |||
| var buffer = ToGraphDef(status); | |||
| status.Check(true); | |||
| status.Dispose(); | |||
| var def = GraphDef.Parser.ParseFrom(buffer); | |||
| buffer.Dispose(); | |||
| GraphDef def; | |||
| using (var status = new Status()) | |||
| using (var buffer = ToGraphDef(status)) | |||
| { | |||
| status.Check(true); | |||
| def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
| } | |||
| // Strip the experimental library field iff it's empty. | |||
| // if(def.Library.Function.Count == 0) | |||
| @@ -45,7 +46,7 @@ namespace Tensorflow | |||
| return def; | |||
| } | |||
| public GraphDef as_graph_def(bool add_shapes = false) | |||
| public GraphDef as_graph_def(bool add_shapes = false) | |||
| => _as_graph_def(add_shapes); | |||
| } | |||
| } | |||
| } | |||
| @@ -30,11 +30,10 @@ namespace Tensorflow | |||
| var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); | |||
| c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); | |||
| for (int i = 0; i < num_return_outputs; i++) | |||
| { | |||
| var handle = return_output_handle + i * size; | |||
| return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle); | |||
| } | |||
| var tf_output_ptr = (TF_Output*) return_output_handle; | |||
| for (int i = 0; i < num_return_outputs; i++) | |||
| return_outputs[i] = *(tf_output_ptr + i); | |||
| Marshal.FreeHGlobal(return_output_handle); | |||
| @@ -18,6 +18,7 @@ using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Runtime.InteropServices; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| @@ -30,7 +31,7 @@ namespace Tensorflow | |||
| using (var status = new Status()) | |||
| { | |||
| c_api.TF_GraphGetOpDef(_handle, type, buffer, status); | |||
| return OpDef.Parser.ParseFrom(buffer.Data); | |||
| return OpDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
| } | |||
| } | |||
| @@ -39,16 +40,20 @@ namespace Tensorflow | |||
| return c_api.TF_NewOperation(_handle, opType, opName); | |||
| } | |||
| public unsafe Operation[] ReturnOperations(IntPtr results) | |||
| public Operation[] ReturnOperations(IntPtr results) | |||
| { | |||
| TF_Operation return_oper_handle = new TF_Operation(); | |||
| int num_return_opers = 0; | |||
| c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); | |||
| Operation[] return_opers = new Operation[num_return_opers]; | |||
| var tf_op_size = Marshal.SizeOf<TF_Operation>(); | |||
| for (int i = 0; i < num_return_opers; i++) | |||
| { | |||
| var handle = return_oper_handle.node + Marshal.SizeOf<TF_Operation>() * i; | |||
| return_opers[i] = new Operation(*(IntPtr*)handle); | |||
| unsafe | |||
| { | |||
| var handle = return_oper_handle.node + tf_op_size * i; | |||
| return_opers[i] = new Operation(*(IntPtr*)handle); | |||
| } | |||
| } | |||
| return return_opers; | |||
| @@ -67,7 +72,7 @@ namespace Tensorflow | |||
| public ITensorOrOperation[] get_operations() | |||
| { | |||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | |||
| return _nodes_by_name.Values.ToArray(); | |||
| } | |||
| /// <summary> | |||
| @@ -81,7 +86,7 @@ namespace Tensorflow | |||
| public ITensorOrOperation _get_operation_by_name_unsafe(string name) | |||
| { | |||
| return _nodes_by_name.ContainsKey(name) ? _nodes_by_name[name] : null; | |||
| return _nodes_by_name.TryGetValue(name, out var val) ? val : null; | |||
| } | |||
| public ITensorOrOperation _get_operation_by_tf_operation(IntPtr tf_oper) | |||
| @@ -23,57 +23,58 @@ using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| /* | |||
| A TensorFlow computation, represented as a dataflow graph. | |||
| A `Graph` contains a set of | |||
| `tf.Operation` objects, | |||
| which represent units of computation; and | |||
| `tf.Tensor` objects, which represent | |||
| the units of data that flow between operations. | |||
| A default `Graph` is always registered, and accessible by calling | |||
| `tf.get_default_graph`. | |||
| To add an operation to the default graph, simply call one of the functions | |||
| that defines a new `Operation`: | |||
| ```python | |||
| c = tf.constant(4.0) | |||
| assert c.graph is tf.get_default_graph() | |||
| ``` | |||
| Another typical usage involves the | |||
| `tf.Graph.as_default` | |||
| context manager, which overrides the current default graph for the | |||
| lifetime of the context: | |||
| ```python | |||
| g = tf.Graph() | |||
| with g.as_default(): | |||
| # Define operations and tensors in `g`. | |||
| c = tf.constant(30.0) | |||
| assert c.graph is g | |||
| ``` | |||
| Important note: This class *is not* thread-safe for graph construction. All | |||
| operations should be created from a single thread, or external | |||
| synchronization must be provided. Unless otherwise specified, all methods | |||
| are not thread-safe. | |||
| A `Graph` instance supports an arbitrary number of "collections" | |||
| that are identified by name. For convenience when building a large | |||
| graph, collections can store groups of related objects: for | |||
| example, the `tf.Variable` uses a collection (named | |||
| `tf.GraphKeys.GLOBAL_VARIABLES`) for | |||
| all variables that are created during the construction of a graph. The caller | |||
| may define additional collections by specifying a new name. | |||
| */ | |||
| /// <summary> | |||
| /// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations. | |||
| /// This leads to a low-level programming model in which you first define the dataflow graph, | |||
| /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | |||
| /// https://www.tensorflow.org/guide/graphs | |||
| /// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations. | |||
| /// This leads to a low-level programming model in which you first define the dataflow graph, | |||
| /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | |||
| /// </summary> | |||
| /* | |||
| A TensorFlow computation, represented as a dataflow graph. | |||
| A `Graph` contains a set of | |||
| `tf.Operation` objects, | |||
| which represent units of computation; and | |||
| `tf.Tensor` objects, which represent | |||
| the units of data that flow between operations. | |||
| A default `Graph` is always registered, and accessible by calling | |||
| `tf.get_default_graph`. | |||
| To add an operation to the default graph, simply call one of the functions | |||
| that defines a new `Operation`: | |||
| ```python | |||
| c = tf.constant(4.0) | |||
| assert c.graph is tf.get_default_graph() | |||
| ``` | |||
| Another typical usage involves the | |||
| `tf.Graph.as_default` | |||
| context manager, which overrides the current default graph for the | |||
| lifetime of the context: | |||
| ```python | |||
| g = tf.Graph() | |||
| with g.as_default(): | |||
| # Define operations and tensors in `g`. | |||
| c = tf.constant(30.0) | |||
| assert c.graph is g | |||
| ``` | |||
| Important note: This class *is not* thread-safe for graph construction. All | |||
| operations should be created from a single thread, or external | |||
| synchronization must be provided. Unless otherwise specified, all methods | |||
| are not thread-safe. | |||
| A `Graph` instance supports an arbitrary number of "collections" | |||
| that are identified by name. For convenience when building a large | |||
| graph, collections can store groups of related objects: for | |||
| example, the `tf.Variable` uses a collection (named | |||
| `tf.GraphKeys.GLOBAL_VARIABLES`) for | |||
| all variables that are created during the construction of a graph. The caller | |||
| may define additional collections by specifying a new name. | |||
| */ | |||
| /// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks> | |||
| public partial class Graph : DisposableObject, IEnumerable<Operation> | |||
| { | |||
| private Dictionary<int, ITensorOrOperation> _nodes_by_id; | |||
| @@ -368,7 +369,7 @@ namespace Tensorflow | |||
| var name_key = name.ToLower(); | |||
| int i = 0; | |||
| if (_names_in_use.ContainsKey(name_key)) | |||
| i = _names_in_use[name_key]; | |||
| i = _names_in_use[name_key]; | |||
| // Increment the number for "name_key". | |||
| if (mark_as_used) | |||
| _names_in_use[name_key] = i + 1; | |||
| @@ -398,13 +399,13 @@ namespace Tensorflow | |||
| int num_return_outputs = 0; | |||
| c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle); | |||
| TF_Output[] return_outputs = new TF_Output[num_return_outputs]; | |||
| for (int i = 0; i < num_return_outputs; i++) | |||
| unsafe | |||
| { | |||
| var handle = return_output_handle + (Marshal.SizeOf<TF_Output>() * i); | |||
| return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle); | |||
| var tf_output_ptr = (TF_Output*) return_output_handle; | |||
| for (int i = 0; i < num_return_outputs; i++) | |||
| return_outputs[i] = *(tf_output_ptr + i); | |||
| return return_outputs; | |||
| } | |||
| return return_outputs; | |||
| } | |||
| public string[] get_all_collection_keys() | |||
| @@ -439,12 +440,12 @@ namespace Tensorflow | |||
| _unfetchable_ops.Add(op); | |||
| } | |||
| protected override void DisposeManagedState() | |||
| protected override void DisposeManagedResources() | |||
| { | |||
| ops.default_graph_stack.remove(this); | |||
| } | |||
| protected override void DisposeUnManagedState(IntPtr handle) | |||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||
| { | |||
| c_api.TF_DeleteGraph(handle); | |||
| } | |||
| @@ -496,11 +497,9 @@ namespace Tensorflow | |||
| IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator() | |||
| => GetEnumerable().GetEnumerator(); | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| => throw new NotImplementedException(); | |||
| public static implicit operator IntPtr(Graph graph) | |||
| { | |||
| return graph._handle; | |||
| @@ -20,7 +20,8 @@ namespace Tensorflow | |||
| { | |||
| public class ImportGraphDefOptions : DisposableObject | |||
| { | |||
| public int NumReturnOutputs => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); | |||
| public int NumReturnOutputs | |||
| => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); | |||
| public ImportGraphDefOptions() | |||
| { | |||
| @@ -37,7 +38,7 @@ namespace Tensorflow | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | |||
| } | |||
| protected override void DisposeUnManagedState(IntPtr handle) | |||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||
| => c_api.TF_DeleteImportGraphDefOptions(handle); | |||
| public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle; | |||
| @@ -16,6 +16,7 @@ | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| namespace Tensorflow.IO | |||
| { | |||
| @@ -28,6 +29,9 @@ namespace Tensorflow.IO | |||
| /// <param name="in_order">Traverse in order if True, post order if False.</param> | |||
| public IEnumerable<(string, string[], string[])> Walk(string top, bool in_order = true) | |||
| { | |||
| if (!Directory.Exists(top)) | |||
| return Enumerable.Empty<(string, string[], string[])>(); | |||
| return walk_v2(top, in_order); | |||
| } | |||
| @@ -81,7 +81,7 @@ namespace Tensorflow.Layers | |||
| // Update global default collections. | |||
| _add_elements_to_collection(_updates.ToArray(), new string[] { ops.GraphKeys.UPDATE_OPS }); | |||
| _add_elements_to_collection(_updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS }); | |||
| return outputs; | |||
| } | |||
| @@ -152,9 +152,9 @@ namespace Tensorflow.Operations | |||
| public (T, Tensor) BuildCondBranch<T>(Func<T> fn) | |||
| { | |||
| // Add the subgraph defined by fn() to the graph. | |||
| var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||
| var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||
| var original_result = fn(); | |||
| var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||
| var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||
| //TODO: port this chunck of missing code: | |||
| /* | |||
| @@ -191,9 +191,9 @@ namespace Tensorflow.Operations | |||
| public (T[], Tensor[]) BuildCondBranch<T>(Func<T[]> fn) | |||
| { | |||
| // Add the subgraph defined by fn() to the graph. | |||
| var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||
| var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||
| var original_result = fn(); | |||
| var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||
| var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||
| switch (original_result) | |||
| { | |||
| @@ -141,7 +141,7 @@ namespace Tensorflow.Operations | |||
| data, frame_name, is_constant, parallel_iterations, name: name); | |||
| if (use_input_shape) | |||
| result.SetShape(data.TensorShape); | |||
| result.set_shape(data.TensorShape); | |||
| return result; | |||
| } | |||
| @@ -195,7 +195,7 @@ namespace Tensorflow.Operations | |||
| // their associated TensorArrays for calling the body. | |||
| var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); | |||
| var body_result = body(packed_vars_for_body[0]); | |||
| var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||
| var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||
| // Store body_result to keep track of TensorArrays returned by body | |||
| var original_body_result = new[] { body_result }; | |||
| @@ -0,0 +1,59 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Operations.Initializers | |||
| { | |||
| public class RandomNormal : IInitializer | |||
| { | |||
| private float mean; | |||
| private float stddev; | |||
| private int? seed; | |||
| private TF_DataType dtype; | |||
| public RandomNormal(float mean = 0.0f, | |||
| float stddev = 1.0f, | |||
| int? seed = null, | |||
| TF_DataType dtype = TF_DataType.TF_FLOAT) | |||
| { | |||
| this.mean = mean; | |||
| this.stddev = stddev; | |||
| this.seed = seed; | |||
| this.dtype = dtype; | |||
| } | |||
| public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) | |||
| { | |||
| if (dtype == TF_DataType.DtInvalid) | |||
| dtype = this.dtype; | |||
| return random_ops.random_normal(shape, mean, stddev, dtype, seed: seed); | |||
| } | |||
| public object get_config() | |||
| { | |||
| return new | |||
| { | |||
| mean, | |||
| stddev, | |||
| seed, | |||
| dtype | |||
| }; | |||
| } | |||
| } | |||
| } | |||
| @@ -2,7 +2,7 @@ | |||
| { | |||
| public class Util | |||
| { | |||
| public static void add_loss(Tensor loss, string loss_collection = ops.GraphKeys.LOSSES) | |||
| public static void add_loss(Tensor loss, string loss_collection = "losses") | |||
| { | |||
| if (!string.IsNullOrEmpty(loss_collection)) | |||
| ops.add_to_collection(loss_collection, loss); | |||
| @@ -22,7 +22,7 @@ namespace Tensorflow | |||
| public class LossesImpl | |||
| { | |||
| public Tensor compute_weighted_loss(Tensor losses, Tensor weights = null, string scope = null, | |||
| string loss_collection = ops.GraphKeys.LOSSES, string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) | |||
| string loss_collection = "losses", string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) | |||
| { | |||
| return tf_with(ops.name_scope(scope, default_name: "weighted_loss", (losses, weights)), delegate | |||
| { | |||
| @@ -101,7 +101,7 @@ namespace Tensorflow | |||
| Tensor logits, | |||
| float weights = 1.0f, | |||
| string scope = null, | |||
| string loss_collection= ops.GraphKeys.LOSSES, | |||
| string loss_collection= "losses", | |||
| string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) | |||
| { | |||
| return tf_with(ops.name_scope(scope, | |||
| @@ -181,6 +181,31 @@ namespace Tensorflow.Operations | |||
| return _op.outputs; | |||
| } | |||
| /// <summary> | |||
| /// Local Response Normalization. | |||
| /// </summary> | |||
| /// <param name="input"></param> | |||
| /// <param name="depth_radius"></param> | |||
| /// <param name="bias"></param> | |||
| /// <param name="alpha"></param> | |||
| /// <param name="beta"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public static Tensor local_response_normalization(Tensor input, int depth_radius = 5, int bias = 1, | |||
| int alpha = 1, float beta = 0.5f, string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("LRN", name: name, args: new | |||
| { | |||
| input, | |||
| depth_radius, | |||
| bias, | |||
| alpha, | |||
| beta | |||
| }); | |||
| return _op.output; | |||
| } | |||
| public static Tensor log_softmax(Tensor logits, string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("LogSoftmax", name: name, args: new | |||
| @@ -189,6 +214,17 @@ namespace Tensorflow.Operations | |||
| }); | |||
| return _op.outputs[0]; | |||
| } | |||
| public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("LeakyRelu", name: name, args: new | |||
| { | |||
| features, | |||
| alpha | |||
| }); | |||
| return _op.output; | |||
| } | |||
| public static Tensor max_pool(Tensor input, | |||
| @@ -233,7 +233,7 @@ namespace Tensorflow.Operations | |||
| dims.AddRange(x_static_shape.dims.Skip(2)); | |||
| var shape = new TensorShape(dims.ToArray()); | |||
| x_t.SetShape(shape); | |||
| x_t.set_shape(shape); | |||
| return x_t; | |||
| } | |||
| @@ -50,14 +50,12 @@ namespace Tensorflow | |||
| public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) | |||
| { | |||
| int size = Marshal.SizeOf<TF_Input>(); | |||
| var handle = Marshal.AllocHGlobal(size); | |||
| var handle = Marshal.AllocHGlobal(Marshal.SizeOf<TF_Input>()); | |||
| int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); | |||
| var consumers = new TF_Input[num]; | |||
| var inputptr = (TF_Input*) handle; | |||
| for (int i = 0; i < num; i++) | |||
| { | |||
| consumers[i] = Marshal.PtrToStructure<TF_Input>(handle + i * size); | |||
| } | |||
| consumers[i] = *(inputptr + i); | |||
| return consumers; | |||
| } | |||
| @@ -17,7 +17,9 @@ | |||
| using Google.Protobuf.Collections; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -226,9 +228,12 @@ namespace Tensorflow | |||
| using (var status = new Status()) | |||
| using (var buf = new Buffer()) | |||
| { | |||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||
| status.Check(true); | |||
| x = AttrValue.Parser.ParseFrom(buf); | |||
| unsafe | |||
| { | |||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||
| status.Check(true); | |||
| x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); | |||
| } | |||
| } | |||
| string oneof_value = x.ValueCase.ToString(); | |||
| @@ -259,7 +264,7 @@ namespace Tensorflow | |||
| { | |||
| c_api.TF_OperationToNodeDef(_handle, buffer, s); | |||
| s.Check(); | |||
| return NodeDef.Parser.ParseFrom(buffer); | |||
| return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
| } | |||
| } | |||
| @@ -299,8 +304,7 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public TF_Output _tf_output(int output_idx) | |||
| { | |||
| var tf_output = new TF_Output(op, output_idx); | |||
| return tf_output; | |||
| return new TF_Output(op, output_idx); | |||
| } | |||
| /// <summary> | |||
| @@ -308,8 +312,7 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public TF_Input _tf_input(int input_idx) | |||
| { | |||
| var tf_input = new TF_Input(op, input_idx); | |||
| return tf_input; | |||
| return new TF_Input(op, input_idx); | |||
| } | |||
| } | |||
| } | |||
| @@ -260,8 +260,7 @@ namespace Tensorflow | |||
| return tf_with(ops.name_scope(name, "ones", new { dims }), scope => | |||
| { | |||
| name = scope; | |||
| var shape = ops.convert_to_tensor(dims, dtype: TF_DataType.TF_INT32); | |||
| var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name); | |||
| var output = _constant_if_small(1, dims, dtype, name); | |||
| return output; | |||
| }); | |||
| } | |||
| @@ -351,7 +350,7 @@ namespace Tensorflow | |||
| var input_shape = tensor_util.to_shape(input_tensor.shape); | |||
| if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) | |||
| { | |||
| var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_datatype()); | |||
| var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype()); | |||
| return constant_op.constant(nd, name: name); | |||
| } | |||
| } | |||
| @@ -431,8 +431,8 @@ namespace Tensorflow | |||
| merges = _convert_flows_to_tensorarrays(new Tensor[] { (Tensor)orig_res_t }, merges); | |||
| ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t); | |||
| ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f); | |||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||
| return merges[0]; | |||
| }); | |||
| @@ -479,8 +479,8 @@ namespace Tensorflow | |||
| merges = _convert_flows_to_tensorarrays(orig_res_t, merges); | |||
| ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t); | |||
| ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f); | |||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||
| return merges; | |||
| }); | |||
| @@ -596,7 +596,7 @@ namespace Tensorflow | |||
| swap_memory: swap_memory); | |||
| if (loop_context.outer_context == null) | |||
| ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context); | |||
| ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context); | |||
| var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, | |||
| return_same_structure); | |||
| @@ -23,7 +23,7 @@ namespace Tensorflow | |||
| { | |||
| public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||
| public Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, string name= null) | |||
| public static Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, string name= null) | |||
| { | |||
| if (dtype == image.dtype) | |||
| return array_ops.identity(image, name: name); | |||
| @@ -57,7 +57,7 @@ namespace Tensorflow | |||
| }); | |||
| } | |||
| public Tensor decode_jpeg(Tensor contents, | |||
| public static Tensor decode_jpeg(Tensor contents, | |||
| int channels = 0, | |||
| int ratio = 1, | |||
| bool fancy_upscaling = true, | |||
| @@ -88,7 +88,70 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| public Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, string name = null) | |||
| public static Tensor decode_gif(Tensor contents, | |||
| string name = null) | |||
| { | |||
| // Add nodes to the TensorFlow graph. | |||
| if (tf.context.executing_eagerly()) | |||
| { | |||
| throw new NotImplementedException("decode_gif"); | |||
| } | |||
| else | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("DecodeGif", name: name, args: new | |||
| { | |||
| contents | |||
| }); | |||
| return _op.output; | |||
| } | |||
| } | |||
| public static Tensor decode_png(Tensor contents, | |||
| int channels = 0, | |||
| TF_DataType dtype = TF_DataType.TF_UINT8, | |||
| string name = null) | |||
| { | |||
| // Add nodes to the TensorFlow graph. | |||
| if (tf.context.executing_eagerly()) | |||
| { | |||
| throw new NotImplementedException("decode_png"); | |||
| } | |||
| else | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("DecodePng", name: name, args: new | |||
| { | |||
| contents, | |||
| channels, | |||
| dtype | |||
| }); | |||
| return _op.output; | |||
| } | |||
| } | |||
| public static Tensor decode_bmp(Tensor contents, | |||
| int channels = 0, | |||
| string name = null) | |||
| { | |||
| // Add nodes to the TensorFlow graph. | |||
| if (tf.context.executing_eagerly()) | |||
| { | |||
| throw new NotImplementedException("decode_bmp"); | |||
| } | |||
| else | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("DecodeBmp", name: name, args: new | |||
| { | |||
| contents, | |||
| channels | |||
| }); | |||
| return _op.output; | |||
| } | |||
| } | |||
| public static Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, string name = null) | |||
| { | |||
| if (tf.context.executing_eagerly()) | |||
| { | |||
| @@ -141,7 +141,7 @@ namespace Tensorflow | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); | |||
| return _op.outputs[0]; | |||
| return _op.output; | |||
| } | |||
| public static Tensor atan(Tensor x, string name = null) | |||
| @@ -40,7 +40,7 @@ namespace Tensorflow | |||
| name: name, | |||
| args: new { shape, dtype, seed, seed2 }); | |||
| return _op.outputs[0]; | |||
| return _op.output; | |||
| } | |||
| /// <summary> | |||
| @@ -0,0 +1,42 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class gen_string_ops | |||
| { | |||
| static readonly OpDefLibrary _op_def_lib; | |||
| static gen_string_ops() { _op_def_lib = new OpDefLibrary(); } | |||
| public static Tensor substr(Tensor input, int pos, int len, | |||
| string name = null, string @uint = "BYTE") | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("Substr", name: name, args: new | |||
| { | |||
| input, | |||
| pos, | |||
| len, | |||
| unit = @uint | |||
| }); | |||
| return _op.output; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,132 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| public class image_ops_impl | |||
| { | |||
| public static Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype = TF_DataType.TF_UINT8, | |||
| string name = null, bool expand_animations = true) | |||
| { | |||
| Tensor substr = null; | |||
| Func<ITensorOrOperation> _jpeg = () => | |||
| { | |||
| int jpeg_channels = channels; | |||
| var good_channels = math_ops.not_equal(jpeg_channels, 4, name: "check_jpeg_channels"); | |||
| string channels_msg = "Channels must be in (None, 0, 1, 3) when decoding JPEG 'images'"; | |||
| var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg }); | |||
| return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate | |||
| { | |||
| return convert_image_dtype(gen_image_ops.decode_jpeg(contents, channels), dtype); | |||
| }); | |||
| }; | |||
| Func<ITensorOrOperation> _gif = () => | |||
| { | |||
| int gif_channels = channels; | |||
| var good_channels = math_ops.logical_and( | |||
| math_ops.not_equal(gif_channels, 1, name: "check_gif_channels"), | |||
| math_ops.not_equal(gif_channels, 4, name: "check_gif_channels")); | |||
| string channels_msg = "Channels must be in (None, 0, 3) when decoding GIF images"; | |||
| var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg }); | |||
| return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate | |||
| { | |||
| var result = convert_image_dtype(gen_image_ops.decode_gif(contents), dtype); | |||
| if (!expand_animations) | |||
| // result = array_ops.gather(result, 0); | |||
| throw new NotImplementedException(""); | |||
| return result; | |||
| }); | |||
| }; | |||
| Func<ITensorOrOperation> _bmp = () => | |||
| { | |||
| int bmp_channels = channels; | |||
| var signature = string_ops.substr(contents, 0, 2); | |||
| var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp"); | |||
| string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP"; | |||
| var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg }); | |||
| var good_channels = math_ops.not_equal(bmp_channels, 1, name: "check_channels"); | |||
| string channels_msg = "Channels must be in (None, 0, 3) when decoding BMP images"; | |||
| var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg }); | |||
| return tf_with(ops.control_dependencies(new[] { assert_decode, assert_channels }), delegate | |||
| { | |||
| return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype); | |||
| }); | |||
| }; | |||
| Func<ITensorOrOperation> _png = () => | |||
| { | |||
| return convert_image_dtype(gen_image_ops.decode_png( | |||
| contents, | |||
| channels, | |||
| dtype: dtype), | |||
| dtype); | |||
| }; | |||
| Func<ITensorOrOperation> check_gif = () => | |||
| { | |||
| var is_gif = math_ops.equal(substr, "\x47\x49\x46", name: "is_gif"); | |||
| return control_flow_ops.cond(is_gif, _gif, _bmp, name: "cond_gif"); | |||
| }; | |||
| Func<ITensorOrOperation> check_png = () => | |||
| { | |||
| return control_flow_ops.cond(_is_png(contents), _png, check_gif, name: "cond_png"); | |||
| }; | |||
| return tf_with(ops.name_scope(name, "decode_image"), scope => | |||
| { | |||
| substr = string_ops.substr(contents, 0, 3); | |||
| return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg"); | |||
| }); | |||
| } | |||
| public static Tensor is_jpeg(Tensor contents, string name = null) | |||
| { | |||
| return tf_with(ops.name_scope(name, "is_jpeg"), scope => | |||
| { | |||
| var substr = string_ops.substr(contents, 0, 3); | |||
| return math_ops.equal(substr, "\xff\xd8\xff", name: name); | |||
| }); | |||
| } | |||
| public static Tensor _is_png(Tensor contents, string name = null) | |||
| { | |||
| return tf_with(ops.name_scope(name, "is_png"), scope => | |||
| { | |||
| var substr = string_ops.substr(contents, 0, 3); | |||
| return math_ops.equal(substr, @"\211PN", name: name); | |||
| }); | |||
| } | |||
| public static Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, | |||
| string name = null) | |||
| { | |||
| if (dtype == image.dtype) | |||
| return array_ops.identity(image, name: name); | |||
| throw new NotImplementedException(""); | |||
| } | |||
| } | |||
| } | |||
| @@ -168,6 +168,9 @@ namespace Tensorflow | |||
| public static Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null) | |||
| => gen_math_ops.mul(x, y, name: name); | |||
| public static Tensor not_equal<Tx, Ty>(Tx x, Ty y, string name = null) | |||
| => gen_math_ops.not_equal(x, y, name: name); | |||
| public static Tensor mul_no_nan<Tx, Ty>(Tx x, Ty y, string name = null) | |||
| => gen_math_ops.mul_no_nan(x, y, name: name); | |||
| @@ -264,6 +267,9 @@ namespace Tensorflow | |||
| return gen_math_ops.log(x, name); | |||
| } | |||
| public static Tensor logical_and(Tensor x, Tensor y, string name = null) | |||
| => gen_math_ops.logical_and(x, y, name: name); | |||
| public static Tensor lgamma(Tensor x, string name = null) | |||
| => gen_math_ops.lgamma(x, name: name); | |||
| @@ -98,7 +98,7 @@ namespace Tensorflow | |||
| // float to be selected, hence we use a >= comparison. | |||
| var keep_mask = random_tensor >= rate; | |||
| var ret = x * scale * math_ops.cast(keep_mask, x.dtype); | |||
| ret.SetShape(x.TensorShape); | |||
| ret.set_shape(x.TensorShape); | |||
| return ret; | |||
| }); | |||
| } | |||
| @@ -116,6 +116,19 @@ namespace Tensorflow | |||
| return _softmax(logits, gen_nn_ops.log_softmax, axis, name); | |||
| } | |||
| public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) | |||
| { | |||
| return tf_with(ops.name_scope(name, "LeakyRelu", new { features, alpha }), scope => | |||
| { | |||
| name = scope; | |||
| features = ops.convert_to_tensor(features, name: "features"); | |||
| if (features.dtype.is_integer()) | |||
| features = math_ops.cast(features, dtypes.float32); | |||
| return gen_nn_ops.leaky_relu(features, alpha: alpha, name: name); | |||
| //return math_ops.maximum(alpha * features, features, name: name); | |||
| }); | |||
| } | |||
| /// <summary> | |||
| /// Performs the max pooling on the input. | |||
| /// </summary> | |||
| @@ -39,9 +39,10 @@ namespace Tensorflow | |||
| { | |||
| return tf_with(ops.name_scope(name, "random_normal", new { shape, mean, stddev }), scope => | |||
| { | |||
| name = scope; | |||
| var shape_tensor = _ShapeTensor(shape); | |||
| var mean_tensor = ops.convert_to_tensor(mean, dtype: dtype, name: "mean"); | |||
| var stddev_tensor = ops.convert_to_tensor(stddev, dtype: dtype, name = "stddev"); | |||
| var stddev_tensor = ops.convert_to_tensor(stddev, dtype: dtype, name: "stddev"); | |||
| var (seed1, seed2) = random_seed.get_seed(seed); | |||
| var rnd = gen_random_ops.random_standard_normal(shape_tensor, dtype: dtype, seed: seed1, seed2: seed2); | |||
| var mul = rnd * stddev_tensor; | |||
| @@ -0,0 +1,38 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class string_ops | |||
| { | |||
| /// <summary> | |||
| /// Return substrings from `Tensor` of strings. | |||
| /// </summary> | |||
| /// <param name="input"></param> | |||
| /// <param name="pos"></param> | |||
| /// <param name="len"></param> | |||
| /// <param name="name"></param> | |||
| /// <param name="uint"></param> | |||
| /// <returns></returns> | |||
| public static Tensor substr(Tensor input, int pos, int len, | |||
| string name = null, string @uint = "BYTE") | |||
| => gen_string_ops.substr(input, pos, len, name: name, @uint: @uint); | |||
| } | |||
| } | |||
| @@ -1,408 +1,413 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Numerics; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class BaseSession : DisposableObject | |||
| { | |||
| protected Graph _graph; | |||
| protected bool _opened; | |||
| protected bool _closed; | |||
| protected int _current_version; | |||
| protected byte[] _target; | |||
| public Graph graph => _graph; | |||
| public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) | |||
| { | |||
| _graph = g is null ? ops.get_default_graph() : g; | |||
| _graph.as_default(); | |||
| _target = UTF8Encoding.UTF8.GetBytes(target); | |||
| SessionOptions newOpts = null; | |||
| if (opts == null) | |||
| newOpts = new SessionOptions(); | |||
| var status = new Status(); | |||
| _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); | |||
| // dispose newOpts | |||
| if (opts == null) | |||
| c_api.TF_DeleteSessionOptions(newOpts); | |||
| status.Check(true); | |||
| } | |||
| public virtual void run(Operation op, params FeedItem[] feed_dict) | |||
| { | |||
| _run(op, feed_dict); | |||
| } | |||
| public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict) | |||
| { | |||
| return _run(fetche, feed_dict)[0]; | |||
| } | |||
| public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); | |||
| return (results[0], results[1], results[2], results[3]); | |||
| } | |||
| public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); | |||
| return (results[0], results[1], results[2]); | |||
| } | |||
| public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); | |||
| return (results[0], results[1]); | |||
| } | |||
| public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict) | |||
| { | |||
| return _run(fetches, feed_dict); | |||
| } | |||
| public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) | |||
| { | |||
| var feed_items = feed_dict == null ? new FeedItem[0] : | |||
| feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||
| return _run(fetches, feed_items); | |||
| } | |||
| private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) | |||
| { | |||
| var feed_dict_tensor = new Dictionary<object, object>(); | |||
| var feed_map = new Dictionary<object, object>(); | |||
| Func<FeedItem, IEnumerable<(object, object)>> feed_fn = (item) => | |||
| { | |||
| return new (object, object)[] { (item.Key, item.Value) }; | |||
| }; | |||
| // Validate and process feed_dict. | |||
| if (feed_dict != null) | |||
| { | |||
| foreach (var feed in feed_dict) | |||
| { | |||
| foreach (var (subfeed, subfeed_val) in feed_fn(feed)) | |||
| { | |||
| var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); | |||
| //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used | |||
| feed_dict_tensor[subfeed_t] = subfeed_val; | |||
| feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); | |||
| } | |||
| } | |||
| } | |||
| // Create a fetch handler to take care of the structure of fetches. | |||
| var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | |||
| // Run request and get response. | |||
| // We need to keep the returned movers alive for the following _do_run(). | |||
| // These movers are no longer needed when _do_run() completes, and | |||
| // are deleted when `movers` goes out of scope when this _run() ends. | |||
| var _ = _update_with_movers(); | |||
| var final_fetches = fetch_handler.fetches(); | |||
| var final_targets = fetch_handler.targets(); | |||
| // We only want to really perform the run if fetches or targets are provided, | |||
| // or if the call is a partial run that specifies feeds. | |||
| var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); | |||
| return fetch_handler.build_results(this, results); | |||
| } | |||
| /// <summary> | |||
| /// Runs a step based on the given fetches and feeds. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| /// <param name="target_list">A list of operations to be run, but not fetched.</param> | |||
| /// <param name="fetch_list"></param> | |||
| /// <param name="feed_dict"></param> | |||
| /// <returns> | |||
| /// A list of numpy ndarrays, corresponding to the elements of | |||
| /// `fetch_list`. If the ith element of `fetch_list` contains the | |||
| /// name of an operation, the first Tensor output of that operation | |||
| /// will be returned for that element. | |||
| /// </returns> | |||
| private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | |||
| { | |||
| var feeds = feed_dict.Select(x => | |||
| { | |||
| if (x.Key is Tensor tensor) | |||
| { | |||
| switch (x.Value) | |||
| { | |||
| #if _REGEN | |||
| %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] | |||
| %foreach types% | |||
| case #1 v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case #1[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| % | |||
| #else | |||
| case sbyte v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case sbyte[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case byte v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case byte[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case short v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case short[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case ushort v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case ushort[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case int v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case int[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case uint v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case uint[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case long v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case long[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case ulong v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case ulong[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case float v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case float[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case double v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case double[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case Complex v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case Complex[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| #endif | |||
| case bool v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL)); | |||
| case string v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case IntPtr v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case Tensor v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); | |||
| case NDArray v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); | |||
| default: | |||
| throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "<null>")}"); | |||
| } | |||
| } | |||
| throw new NotImplementedException("_do_run.feed_dict"); | |||
| }).ToArray(); | |||
| var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||
| var targets = target_list; | |||
| return _call_tf_sessionrun(feeds, fetches, target_list); | |||
| } | |||
| private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) | |||
| { | |||
| // Ensure any changes to the graph are reflected in the runtime. | |||
| _extend_graph(); | |||
| var status = new Status(); | |||
| var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | |||
| c_api.TF_SessionRun(_handle, | |||
| run_options: null, | |||
| inputs: feed_dict.Select(f => f.Key).ToArray(), | |||
| input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), | |||
| ninputs: feed_dict.Length, | |||
| outputs: fetch_list, | |||
| output_values: output_values, | |||
| noutputs: fetch_list.Length, | |||
| target_opers: target_list.Select(f => (IntPtr)f).ToArray(), | |||
| ntargets: target_list.Count, | |||
| run_metadata: IntPtr.Zero, | |||
| status: status); | |||
| status.Check(true); | |||
| var result = new NDArray[fetch_list.Length]; | |||
| for (int i = 0; i < fetch_list.Length; i++) | |||
| result[i] = fetchValue(output_values[i]); | |||
| for (int i = 0; i < feed_dict.Length; i++) | |||
| feed_dict[i].Value.Dispose(); | |||
| return result; | |||
| } | |||
| private unsafe NDArray fetchValue(IntPtr output) | |||
| { | |||
| var tensor = new Tensor(output); | |||
| NDArray nd = null; | |||
| Type type = tensor.dtype.as_numpy_datatype(); | |||
| var ndims = tensor.shape; | |||
| var offset = c_api.TF_TensorData(output); | |||
| if(ndims.Length == 0) | |||
| { | |||
| switch (tensor.dtype) | |||
| { | |||
| case TF_DataType.TF_BOOL: | |||
| nd = NDArray.Scalar(*(bool*)offset); | |||
| break; | |||
| case TF_DataType.TF_STRING: | |||
| var bytes = tensor.Data(); | |||
| // wired, don't know why we have to start from offset 9. | |||
| // length in the begin | |||
| var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | |||
| nd = NDArray.FromString(str); | |||
| break; | |||
| case TF_DataType.TF_UINT8: | |||
| nd = NDArray.Scalar(*(byte*)offset); | |||
| break; | |||
| case TF_DataType.TF_INT16: | |||
| nd = NDArray.Scalar(*(short*)offset); | |||
| break; | |||
| case TF_DataType.TF_INT32: | |||
| nd = NDArray.Scalar(*(int*)offset); | |||
| break; | |||
| case TF_DataType.TF_INT64: | |||
| nd = NDArray.Scalar(*(long*)offset); | |||
| break; | |||
| case TF_DataType.TF_FLOAT: | |||
| nd = NDArray.Scalar(*(float*)offset); | |||
| break; | |||
| case TF_DataType.TF_DOUBLE: | |||
| nd = NDArray.Scalar(*(double*)offset); | |||
| break; | |||
| default: | |||
| throw new NotImplementedException("can't fetch output"); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| switch (tensor.dtype) | |||
| { | |||
| case TF_DataType.TF_BOOL: | |||
| var bools = new bool[tensor.size]; | |||
| for (ulong i = 0; i < tensor.size; i++) | |||
| bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i)); | |||
| nd = np.array(bools).reshape(ndims); | |||
| break; | |||
| case TF_DataType.TF_STRING: | |||
| var bytes = tensor.Data(); | |||
| // wired, don't know why we have to start from offset 9. | |||
| // length in the begin | |||
| var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | |||
| nd = np.array(str); | |||
| break; | |||
| case TF_DataType.TF_UINT8: | |||
| var _bytes = new byte[tensor.size]; | |||
| for (ulong i = 0; i < tensor.size; i++) | |||
| _bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i)); | |||
| nd = np.array(_bytes).reshape(ndims); | |||
| break; | |||
| case TF_DataType.TF_INT16: | |||
| var shorts = new short[tensor.size]; | |||
| for (ulong i = 0; i < tensor.size; i++) | |||
| shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i)); | |||
| nd = np.array(shorts).reshape(ndims); | |||
| break; | |||
| case TF_DataType.TF_INT32: | |||
| var ints = new int[tensor.size]; | |||
| for (ulong i = 0; i < tensor.size; i++) | |||
| ints[i] = *(int*)(offset + (int)(tensor.itemsize * i)); | |||
| nd = np.array(ints).reshape(ndims); | |||
| break; | |||
| case TF_DataType.TF_INT64: | |||
| var longs = new long[tensor.size]; | |||
| for (ulong i = 0; i < tensor.size; i++) | |||
| longs[i] = *(long*)(offset + (int)(tensor.itemsize * i)); | |||
| nd = np.array(longs).reshape(ndims); | |||
| break; | |||
| case TF_DataType.TF_FLOAT: | |||
| var floats = new float[tensor.size]; | |||
| for (ulong i = 0; i < tensor.size; i++) | |||
| floats[i] = *(float*)(offset + (int)(tensor.itemsize * i)); | |||
| nd = np.array(floats).reshape(ndims); | |||
| break; | |||
| case TF_DataType.TF_DOUBLE: | |||
| var doubles = new double[tensor.size]; | |||
| for (ulong i = 0; i < tensor.size; i++) | |||
| doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i)); | |||
| nd = np.array(doubles).reshape(ndims); | |||
| break; | |||
| default: | |||
| throw new NotImplementedException("can't fetch output"); | |||
| } | |||
| } | |||
| tensor.Dispose(); | |||
| return nd; | |||
| } | |||
| /// <summary> | |||
| /// If a tensor handle that is fed to a device incompatible placeholder, | |||
| /// we move the tensor to the right device, generate a new tensor handle, | |||
| /// and update feed_dict to use the new handle. | |||
| /// </summary> | |||
| private List<object> _update_with_movers() | |||
| { | |||
| return new List<object> { }; | |||
| } | |||
| private void _extend_graph() | |||
| { | |||
| } | |||
| public void close() | |||
| { | |||
| Dispose(); | |||
| } | |||
| protected override void DisposeUnManagedState(IntPtr handle) | |||
| { | |||
| using (var status = new Status()) | |||
| { | |||
| c_api.TF_DeleteSession(handle, status); | |||
| status.Check(true); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Numerics; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class BaseSession : DisposableObject | |||
| { | |||
| protected Graph _graph; | |||
| protected bool _opened; | |||
| protected bool _closed; | |||
| protected int _current_version; | |||
| protected byte[] _target; | |||
| public Graph graph => _graph; | |||
| public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) | |||
| { | |||
| _graph = g is null ? ops.get_default_graph() : g; | |||
| _graph.as_default(); | |||
| _target = UTF8Encoding.UTF8.GetBytes(target); | |||
| SessionOptions newOpts = null; | |||
| if (opts == null) | |||
| newOpts = new SessionOptions(); | |||
| var status = new Status(); | |||
| _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); | |||
| // dispose newOpts | |||
| if (opts == null) | |||
| newOpts.Dispose(); | |||
| status.Check(true); | |||
| } | |||
| public virtual void run(Operation op, params FeedItem[] feed_dict) | |||
| { | |||
| _run(op, feed_dict); | |||
| } | |||
| public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict) | |||
| { | |||
| return _run(fetche, feed_dict)[0]; | |||
| } | |||
| public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict) | |||
| { | |||
| return _run(fetche, feed_dict)[0]; | |||
| } | |||
| public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); | |||
| return (results[0], results[1], results[2], results[3]); | |||
| } | |||
| public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); | |||
| return (results[0], results[1], results[2]); | |||
| } | |||
| public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); | |||
| return (results[0], results[1]); | |||
| } | |||
| public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict) | |||
| { | |||
| return _run(fetches, feed_dict); | |||
| } | |||
| public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) | |||
| { | |||
| var feed_items = feed_dict == null ? new FeedItem[0] : | |||
| feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||
| return _run(fetches, feed_items); | |||
| } | |||
| private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) | |||
| { | |||
| var feed_dict_tensor = new Dictionary<object, object>(); | |||
| var feed_map = new Dictionary<object, object>(); | |||
| Func<FeedItem, IEnumerable<(object, object)>> feed_fn = (item) => | |||
| { | |||
| return new (object, object)[] { (item.Key, item.Value) }; | |||
| }; | |||
| // Validate and process feed_dict. | |||
| if (feed_dict != null) | |||
| { | |||
| foreach (var feed in feed_dict) | |||
| { | |||
| foreach (var (subfeed, subfeed_val) in feed_fn(feed)) | |||
| { | |||
| var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); | |||
| //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used | |||
| feed_dict_tensor[subfeed_t] = subfeed_val; | |||
| feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); | |||
| } | |||
| } | |||
| } | |||
| // Create a fetch handler to take care of the structure of fetches. | |||
| var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | |||
| // Run request and get response. | |||
| // We need to keep the returned movers alive for the following _do_run(). | |||
| // These movers are no longer needed when _do_run() completes, and | |||
| // are deleted when `movers` goes out of scope when this _run() ends. | |||
| var _ = _update_with_movers(); | |||
| var final_fetches = fetch_handler.fetches(); | |||
| var final_targets = fetch_handler.targets(); | |||
| // We only want to really perform the run if fetches or targets are provided, | |||
| // or if the call is a partial run that specifies feeds. | |||
| var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); | |||
| return fetch_handler.build_results(this, results); | |||
| } | |||
| /// <summary> | |||
| /// Runs a step based on the given fetches and feeds. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| /// <param name="target_list">A list of operations to be run, but not fetched.</param> | |||
| /// <param name="fetch_list"></param> | |||
| /// <param name="feed_dict"></param> | |||
| /// <returns> | |||
| /// A list of numpy ndarrays, corresponding to the elements of | |||
| /// `fetch_list`. If the ith element of `fetch_list` contains the | |||
| /// name of an operation, the first Tensor output of that operation | |||
| /// will be returned for that element. | |||
| /// </returns> | |||
| private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | |||
| { | |||
| var feeds = feed_dict.Select(x => | |||
| { | |||
| if (x.Key is Tensor tensor) | |||
| { | |||
| switch (x.Value) | |||
| { | |||
| #if _REGEN | |||
| %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] | |||
| %foreach types% | |||
| case #1 v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case #1[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| % | |||
| #else | |||
| case sbyte v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case sbyte[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case byte v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case byte[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case short v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case short[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case ushort v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case ushort[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case int v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case int[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case uint v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case uint[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case long v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case long[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case ulong v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case ulong[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case float v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case float[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case double v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case double[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case Complex v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case Complex[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| #endif | |||
| case bool v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL)); | |||
| case string v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case IntPtr v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case Tensor v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); | |||
| case NDArray v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); | |||
| default: | |||
| throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "<null>")}"); | |||
| } | |||
| } | |||
| throw new NotImplementedException("_do_run.feed_dict"); | |||
| }).ToArray(); | |||
| var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||
| var targets = target_list; | |||
| return _call_tf_sessionrun(feeds, fetches, target_list); | |||
| } | |||
| private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) | |||
| { | |||
| // Ensure any changes to the graph are reflected in the runtime. | |||
| _extend_graph(); | |||
| var status = new Status(); | |||
| var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | |||
| c_api.TF_SessionRun(_handle, | |||
| run_options: null, | |||
| inputs: feed_dict.Select(f => f.Key).ToArray(), | |||
| input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), | |||
| ninputs: feed_dict.Length, | |||
| outputs: fetch_list, | |||
| output_values: output_values, | |||
| noutputs: fetch_list.Length, | |||
| target_opers: target_list.Select(f => (IntPtr)f).ToArray(), | |||
| ntargets: target_list.Count, | |||
| run_metadata: IntPtr.Zero, | |||
| status: status); | |||
| status.Check(true); | |||
| var result = new NDArray[fetch_list.Length]; | |||
| for (int i = 0; i < fetch_list.Length; i++) | |||
| result[i] = fetchValue(output_values[i]); | |||
| for (int i = 0; i < feed_dict.Length; i++) | |||
| feed_dict[i].Value.Dispose(); | |||
| return result; | |||
| } | |||
| private unsafe NDArray fetchValue(IntPtr output) | |||
| { | |||
| var tensor = new Tensor(output); | |||
| NDArray nd = null; | |||
| Type type = tensor.dtype.as_numpy_dtype(); | |||
| var ndims = tensor.shape; | |||
| var offset = c_api.TF_TensorData(output); | |||
| if(ndims.Length == 0) | |||
| { | |||
| switch (tensor.dtype) | |||
| { | |||
| case TF_DataType.TF_BOOL: | |||
| nd = NDArray.Scalar(*(bool*)offset); | |||
| break; | |||
| case TF_DataType.TF_STRING: | |||
| var bytes = tensor.BufferToArray(); | |||
| // wired, don't know why we have to start from offset 9. | |||
| // length in the begin | |||
| var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | |||
| nd = NDArray.FromString(str); | |||
| break; | |||
| case TF_DataType.TF_UINT8: | |||
| nd = NDArray.Scalar(*(byte*)offset); | |||
| break; | |||
| case TF_DataType.TF_INT16: | |||
| nd = NDArray.Scalar(*(short*)offset); | |||
| break; | |||
| case TF_DataType.TF_INT32: | |||
| nd = NDArray.Scalar(*(int*)offset); | |||
| break; | |||
| case TF_DataType.TF_INT64: | |||
| nd = NDArray.Scalar(*(long*)offset); | |||
| break; | |||
| case TF_DataType.TF_FLOAT: | |||
| nd = NDArray.Scalar(*(float*)offset); | |||
| break; | |||
| case TF_DataType.TF_DOUBLE: | |||
| nd = NDArray.Scalar(*(double*)offset); | |||
| break; | |||
| default: | |||
| throw new NotImplementedException("can't fetch output"); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| switch (tensor.dtype) | |||
| { | |||
| case TF_DataType.TF_BOOL: | |||
| var bools = new bool[tensor.size]; | |||
| for (ulong i = 0; i < tensor.size; i++) | |||
| bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i)); | |||
| nd = np.array(bools).reshape(ndims); | |||
| break; | |||
| case TF_DataType.TF_STRING: | |||
| var bytes = tensor.BufferToArray(); | |||
| // wired, don't know why we have to start from offset 9. | |||
| // length in the begin | |||
| var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | |||
| nd = np.array(str); | |||
| break; | |||
| case TF_DataType.TF_UINT8: | |||
| var _bytes = new byte[tensor.size]; | |||
| for (ulong i = 0; i < tensor.size; i++) | |||
| _bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i)); | |||
| nd = np.array(_bytes).reshape(ndims); | |||
| break; | |||
| case TF_DataType.TF_INT16: | |||
| var shorts = new short[tensor.size]; | |||
| for (ulong i = 0; i < tensor.size; i++) | |||
| shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i)); | |||
| nd = np.array(shorts).reshape(ndims); | |||
| break; | |||
| case TF_DataType.TF_INT32: | |||
| var ints = new int[tensor.size]; | |||
| for (ulong i = 0; i < tensor.size; i++) | |||
| ints[i] = *(int*)(offset + (int)(tensor.itemsize * i)); | |||
| nd = np.array(ints).reshape(ndims); | |||
| break; | |||
| case TF_DataType.TF_INT64: | |||
| var longs = new long[tensor.size]; | |||
| for (ulong i = 0; i < tensor.size; i++) | |||
| longs[i] = *(long*)(offset + (int)(tensor.itemsize * i)); | |||
| nd = np.array(longs).reshape(ndims); | |||
| break; | |||
| case TF_DataType.TF_FLOAT: | |||
| var floats = new float[tensor.size]; | |||
| for (ulong i = 0; i < tensor.size; i++) | |||
| floats[i] = *(float*)(offset + (int)(tensor.itemsize * i)); | |||
| nd = np.array(floats).reshape(ndims); | |||
| break; | |||
| case TF_DataType.TF_DOUBLE: | |||
| var doubles = new double[tensor.size]; | |||
| for (ulong i = 0; i < tensor.size; i++) | |||
| doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i)); | |||
| nd = np.array(doubles).reshape(ndims); | |||
| break; | |||
| default: | |||
| throw new NotImplementedException("can't fetch output"); | |||
| } | |||
| } | |||
| tensor.Dispose(); | |||
| return nd; | |||
| } | |||
| /// <summary> | |||
| /// If a tensor handle that is fed to a device incompatible placeholder, | |||
| /// we move the tensor to the right device, generate a new tensor handle, | |||
| /// and update feed_dict to use the new handle. | |||
| /// </summary> | |||
| private List<object> _update_with_movers() | |||
| { | |||
| return new List<object> { }; | |||
| } | |||
| private void _extend_graph() | |||
| { | |||
| } | |||
| public void close() | |||
| { | |||
| Dispose(); | |||
| } | |||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||
| { | |||
| using (var status = new Status()) | |||
| { | |||
| c_api.TF_DeleteSession(handle, status); | |||
| status.Check(true); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -16,5 +16,11 @@ | |||
| public static implicit operator FeedItem((object, object) feed) | |||
| => new FeedItem(feed.Item1, feed.Item2); | |||
| public void Deconstruct(out object key, out object value) | |||
| { | |||
| key = Key; | |||
| value = Value; | |||
| } | |||
| } | |||
| } | |||
| @@ -32,13 +32,13 @@ namespace Tensorflow | |||
| _handle = handle; | |||
| } | |||
| protected override void DisposeUnManagedState(IntPtr handle) | |||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||
| => c_api.TF_DeleteSessionOptions(handle); | |||
| public void SetConfig(ConfigProto config) | |||
| { | |||
| var bytes = config.ToByteArray(); | |||
| var proto = Marshal.AllocHGlobal(bytes.Length); | |||
| var bytes = config.ToByteArray(); //TODO! we can use WriteTo | |||
| var proto = Marshal.AllocHGlobal(bytes.Length); //TODO! potential memory leak | |||
| Marshal.Copy(bytes, 0, proto, bytes.Length); | |||
| using (var status = new Status()) | |||
| @@ -17,6 +17,7 @@ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using NumSharp.Backends; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -71,18 +72,18 @@ namespace Tensorflow | |||
| { | |||
| if(tensor_values.Length > 0) | |||
| { | |||
| switch (tensor_values[0].dtype.Name) | |||
| switch (tensor_values[0].typecode) | |||
| { | |||
| case "Int32": | |||
| case NPTypeCode.Int32: | |||
| full_values.Add(float.NaN); | |||
| break; | |||
| case "Single": | |||
| case NPTypeCode.Single: | |||
| full_values.Add(float.NaN); | |||
| break; | |||
| case "String": | |||
| case NPTypeCode.String: | |||
| full_values.Add(float.NaN); | |||
| break; | |||
| case "Char": | |||
| case NPTypeCode.Char: | |||
| full_values.Add(float.NaN); | |||
| break; | |||
| default: | |||
| @@ -100,21 +101,21 @@ namespace Tensorflow | |||
| j += 1; | |||
| if (value.ndim == 0) | |||
| { | |||
| switch (value.dtype.Name) | |||
| switch (value.typecode) | |||
| { | |||
| case "Int16": | |||
| case NPTypeCode.Int16: | |||
| full_values.Add(value.GetValue<short>(0)); | |||
| break; | |||
| case "Int32": | |||
| case NPTypeCode.Int32: | |||
| full_values.Add(value.GetValue<int>(0)); | |||
| break; | |||
| case "Int64": | |||
| case NPTypeCode.Int64: | |||
| full_values.Add(value.GetValue<long>(0)); | |||
| break; | |||
| case "Single": | |||
| case NPTypeCode.Single: | |||
| full_values.Add(value.GetValue<float>(0)); | |||
| break; | |||
| case "Double": | |||
| case NPTypeCode.Double: | |||
| full_values.Add(value.GetValue<double>(0)); | |||
| break; | |||
| /*case "String": | |||
| @@ -27,13 +27,17 @@ namespace Tensorflow | |||
| var handle = Marshal.AllocHGlobal(size * num_consumers); | |||
| int num = TF_OperationOutputConsumers(oper_out, handle, num_consumers); | |||
| var consumers = new string[num_consumers]; | |||
| for (int i = 0; i < num; i++) | |||
| unsafe | |||
| { | |||
| TF_Input input = Marshal.PtrToStructure<TF_Input>(handle + i * size); | |||
| consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(input.oper)); | |||
| var inputptr = (TF_Input*) handle; | |||
| for (int i = 0; i < num; i++) | |||
| { | |||
| var oper = (inputptr + i)->oper; | |||
| consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(oper)); | |||
| } | |||
| } | |||
| return consumers; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -15,6 +15,8 @@ | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Runtime.CompilerServices; | |||
| using static Tensorflow.c_api; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -27,36 +29,36 @@ namespace Tensorflow | |||
| /// <summary> | |||
| /// Error message | |||
| /// </summary> | |||
| public string Message => c_api.StringPiece(c_api.TF_Message(_handle)); | |||
| public string Message => c_api.StringPiece(TF_Message(_handle)); | |||
| /// <summary> | |||
| /// Error code | |||
| /// </summary> | |||
| public TF_Code Code => c_api.TF_GetCode(_handle); | |||
| public TF_Code Code => TF_GetCode(_handle); | |||
| public Status() | |||
| { | |||
| _handle = c_api.TF_NewStatus(); | |||
| _handle = TF_NewStatus(); | |||
| } | |||
| public void SetStatus(TF_Code code, string msg) | |||
| { | |||
| c_api.TF_SetStatus(_handle, code, msg); | |||
| TF_SetStatus(_handle, code, msg); | |||
| } | |||
| /// <summary> | |||
| /// Check status | |||
| /// Throw exception with error message if code != TF_OK | |||
| /// </summary> | |||
| /// <exception cref="TensorflowException">When the returned check is not TF_Code.TF_OK</exception> | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| public void Check(bool throwException = false) | |||
| { | |||
| if(Code != TF_Code.TF_OK) | |||
| if (Code != TF_Code.TF_OK) | |||
| { | |||
| Console.WriteLine(Message); | |||
| if (throwException) | |||
| { | |||
| throw new Exception(Message); | |||
| } | |||
| throw new TensorflowException(Message); | |||
| } | |||
| } | |||
| @@ -65,7 +67,7 @@ namespace Tensorflow | |||
| return status._handle; | |||
| } | |||
| protected override void DisposeUnManagedState(IntPtr handle) | |||
| => c_api.TF_DeleteStatus(handle); | |||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||
| => TF_DeleteStatus(handle); | |||
| } | |||
| } | |||
| } | |||
| @@ -51,7 +51,7 @@ namespace Tensorflow | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static unsafe extern IntPtr TF_NewStatus(); | |||
| public static extern IntPtr TF_NewStatus(); | |||
| /// <summary> | |||
| /// Record <code, msg> in *s. Any previous information is lost. | |||
| @@ -33,11 +33,11 @@ namespace Tensorflow.Summaries | |||
| { | |||
| var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor }, default_name: "HistogramSummary"); | |||
| var val = gen_logging_ops.histogram_summary(tag: tag, values: tensor, name: scope); | |||
| collect(val, collections?.ToList(), new List<string> { ops.GraphKeys.SUMMARIES }); | |||
| collect(val, collections?.ToList(), new List<string> { tf.GraphKeys.SUMMARIES }); | |||
| return val; | |||
| } | |||
| public Tensor merge_all(string key = ops.GraphKeys.SUMMARIES, string scope= null, string name= null) | |||
| public Tensor merge_all(string key = "summaries", string scope= null, string name= null) | |||
| { | |||
| var summary_ops = ops.get_collection(key, scope: scope); | |||
| if (summary_ops == null) | |||
| @@ -67,7 +67,7 @@ namespace Tensorflow.Summaries | |||
| { | |||
| var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor }); | |||
| var val = gen_logging_ops.scalar_summary(tags: tag, values: tensor, name: scope); | |||
| collect(val, collections?.ToList(), new List<string> { ops.GraphKeys.SUMMARIES }); | |||
| collect(val, collections?.ToList(), new List<string> { tf.GraphKeys.SUMMARIES }); | |||
| return val; | |||
| } | |||
| @@ -5,8 +5,8 @@ | |||
| <AssemblyName>TensorFlow.NET</AssemblyName> | |||
| <RootNamespace>Tensorflow</RootNamespace> | |||
| <TargetTensorFlow>1.14.0</TargetTensorFlow> | |||
| <Version>0.11.0</Version> | |||
| <Authors>Haiping Chen, Meinrad Recheis</Authors> | |||
| <Version>0.11.1</Version> | |||
| <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | |||
| <Company>SciSharp STACK</Company> | |||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | |||
| <Copyright>Apache 2.0</Copyright> | |||
| @@ -17,10 +17,16 @@ | |||
| <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> | |||
| <Description>Google's TensorFlow full binding in .NET Standard. | |||
| Docs: https://tensorflownet.readthedocs.io</Description> | |||
| <AssemblyVersion>0.11.10.0</AssemblyVersion> | |||
| <PackageReleaseNotes>Changes since v0.10.0:</PackageReleaseNotes> | |||
| <AssemblyVersion>0.11.1.0</AssemblyVersion> | |||
| <PackageReleaseNotes>Changes since v0.10.0: | |||
| 1. Upgrade NumSharp to v0.20. | |||
| 2. Add DisposableObject class to manage object lifetime. | |||
| 3. Add tf.no_op, tf.nn.in_top_k, tf.GraphKeys and tf.trainable_variables. | |||
| 4. Change tensorflow to non-static class in order to execute some initialization process. | |||
| 5. Overload session.run(), make syntax simpler. | |||
| 6. Add Local Response Normalization.</PackageReleaseNotes> | |||
| <LangVersion>7.3</LangVersion> | |||
| <FileVersion>0.11.10.0</FileVersion> | |||
| <FileVersion>0.11.1.0</FileVersion> | |||
| <PackageLicenseFile>LICENSE</PackageLicenseFile> | |||
| <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | |||
| <SignAssembly>true</SignAssembly> | |||
| @@ -52,7 +58,7 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||
| <ItemGroup> | |||
| <PackageReference Include="Google.Protobuf" Version="3.5.1" /> | |||
| <PackageReference Include="NumSharp" Version="0.20.0-alpha1" /> | |||
| <PackageReference Include="NumSharp" Version="0.20.0" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| @@ -16,11 +16,13 @@ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Diagnostics.CodeAnalysis; | |||
| using System.Linq; | |||
| using System.Numerics; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Runtime.InteropServices; | |||
| using System.Text; | |||
| using NumSharp.Backends; | |||
| using NumSharp.Backends.Unmanaged; | |||
| using static Tensorflow.c_api; | |||
| @@ -50,9 +52,9 @@ namespace Tensorflow | |||
| private DeallocatorArgs _deallocatorArgs = new DeallocatorArgs() { gc_handle = IntPtr.Zero }; | |||
| // note: they must be assigned to a static variable in order to work as unmanaged callbacks | |||
| static Deallocator _hGlobalDeallocator = FreeHGlobalMemory; | |||
| static Deallocator _gcHandleDeallocator = FreeGCHandle; | |||
| private static Deallocator _nothingDeallocator = FreeNothing; | |||
| private static readonly Deallocator _hGlobalDeallocator = FreeHGlobalMemory; | |||
| private static readonly Deallocator _gcHandleDeallocator = FreeGCHandle; | |||
| private static readonly Deallocator _nothingDeallocator = FreeNothing; | |||
| /// <summary> | |||
| /// Create a Tensor object from an existing TF handle | |||
| @@ -462,7 +464,7 @@ namespace Tensorflow | |||
| *v = value; | |||
| _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(Complex)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(Complex), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); | |||
| IsMemoryOwner=true; | |||
| } | |||
| } | |||
| #endif | |||
| /// <summary> | |||
| @@ -477,7 +479,7 @@ namespace Tensorflow | |||
| IntPtr tensor = c_api.TF_TensorData(handle); | |||
| Marshal.WriteInt64(tensor, 0); | |||
| fixed (byte* src = &buffer[0]) | |||
| fixed (byte* src = buffer) | |||
| c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); | |||
| _handle = handle; | |||
| status.Check(true); | |||
| @@ -486,35 +488,54 @@ namespace Tensorflow | |||
| public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) | |||
| { | |||
| // todo: handle nd of type "String" here too | |||
| if (tensorDType == TF_DataType.TF_STRING && nd.dtype.Name == "Byte") | |||
| if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte) | |||
| { | |||
| var buffer = nd.ToArray<byte>(); | |||
| var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); | |||
| var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); | |||
| IntPtr tensor = c_api.TF_TensorData(handle); | |||
| Marshal.WriteInt64(tensor, 0); | |||
| if (nd.Unsafe.Storage.Shape.IsContiguous) | |||
| { | |||
| var bytesLength = (UIntPtr)nd.size; | |||
| var size = c_api.TF_StringEncodedSize(bytesLength); | |||
| var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); | |||
| IntPtr tensor = c_api.TF_TensorData(handle); | |||
| Marshal.WriteInt64(tensor, 0); | |||
| var status = new Status(); | |||
| c_api.TF_StringEncode((byte*) nd.Unsafe.Address, bytesLength, (sbyte*) (tensor + sizeof(Int64)), size, status); | |||
| status.Check(true); | |||
| _handle = handle; | |||
| IsMemoryOwner = false; | |||
| } | |||
| else | |||
| { | |||
| var buffer = nd.ToArray<byte>(); | |||
| var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length); | |||
| var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); | |||
| IntPtr tensor = c_api.TF_TensorData(handle); | |||
| Marshal.WriteInt64(tensor, 0); | |||
| var status = new Status(); | |||
| fixed (byte* src = buffer) | |||
| c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, status); | |||
| status.Check(true); | |||
| _handle = handle; | |||
| IsMemoryOwner = false; | |||
| } | |||
| var status = new Status(); | |||
| fixed (byte* src = &buffer[0]) | |||
| c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); | |||
| status.Check(true); | |||
| _handle=handle; | |||
| IsMemoryOwner = false; | |||
| return; | |||
| } | |||
| _handle = CreateTensorFromNDArray(nd, tensorDType); | |||
| IsMemoryOwner = true; | |||
| } | |||
| private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype) | |||
| { | |||
| if (nd.dtype.Name == "String") | |||
| throw new NotImplementedException("Support for NDArray of type string not implemented yet"); | |||
| if (nd.dtype.Name == "String") | |||
| throw new NotImplementedException("Support for NDArray of type string not implemented yet"); | |||
| IArraySlice arraySlice; | |||
| var shape = nd.Unsafe.Storage.Shape; | |||
| if (shape.IsSliced || shape.IsBroadcasted) | |||
| if (nd.Unsafe.Storage.Shape.IsContiguous == false) | |||
| { | |||
| // the memory is NOT contiguous, so we have to copy the view into a contiguous memory block. | |||
| arraySlice = nd.CloneData(); | |||
| @@ -527,51 +548,52 @@ namespace Tensorflow | |||
| this.Tag = arraySlice; // keep a reference to the memory block to make sure it is not disposed while TF is using it | |||
| var ptr = new IntPtr(arraySlice.Address); | |||
| int num_bytes = (nd.size * nd.dtypesize); | |||
| var dtype = given_dtype ?? ToTFDataType(nd.dtype); | |||
| var dtype = given_dtype ?? nd.dtype.as_dtype(); | |||
| var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs); | |||
| IsMemoryOwner = false; | |||
| return handle; | |||
| } | |||
| public unsafe Tensor(byte[][] buffer, long[] shape) | |||
| { | |||
| int size = 0; | |||
| foreach (var b in buffer) | |||
| { | |||
| size += (int)TF_StringEncodedSize((UIntPtr)b.Length); | |||
| } | |||
| int totalSize = size + buffer.Length * 8; | |||
| ulong offset = 0; | |||
| IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize); | |||
| // Clear offset table | |||
| IntPtr pOffset = TF_TensorData(handle); | |||
| IntPtr dst = pOffset + buffer.Length * 8; | |||
| IntPtr dstLimit = pOffset + totalSize; | |||
| for (int i = 0; i < buffer.Length; i++) | |||
| { | |||
| Marshal.WriteInt64(pOffset, (long)offset); | |||
| using (var status = new Status()) | |||
| { | |||
| fixed (byte* src = &buffer[i][0]) | |||
| { | |||
| var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status); | |||
| status.Check(true); | |||
| pOffset += 8; | |||
| dst += (int)written; | |||
| offset += written; | |||
| } | |||
| } | |||
| } | |||
| _handle = handle; | |||
| } | |||
| public unsafe Tensor(byte[][] buffer, long[] shape) | |||
| { | |||
| int size = 0; | |||
| foreach (var b in buffer) | |||
| { | |||
| size += (int)TF_StringEncodedSize((UIntPtr)b.Length); | |||
| } | |||
| int totalSize = size + buffer.Length * 8; | |||
| ulong offset = 0; | |||
| IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize); | |||
| // Clear offset table | |||
| IntPtr pOffset = TF_TensorData(handle); | |||
| IntPtr dst = pOffset + buffer.Length * 8; | |||
| IntPtr dstLimit = pOffset + totalSize; | |||
| for (int i = 0; i < buffer.Length; i++) | |||
| { | |||
| Marshal.WriteInt64(pOffset, (long)offset); | |||
| using (var status = new Status()) | |||
| { | |||
| fixed (byte* src = &buffer[i][0]) | |||
| { | |||
| var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status); | |||
| status.Check(true); | |||
| pOffset += 8; | |||
| dst += (int)written; | |||
| offset += written; | |||
| } | |||
| } | |||
| } | |||
| _handle = handle; | |||
| } | |||
| public Tensor(Operation op, int value_index, TF_DataType dtype) | |||
| { | |||
| _op = op; | |||
| _value_index = value_index; | |||
| _dtype = dtype; | |||
| _override_dtype = dtype; | |||
| _id = ops.uid(); | |||
| } | |||
| @@ -589,11 +611,11 @@ namespace Tensorflow | |||
| /// specified dimensions. | |||
| /// </remarks> | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| [SuppressMessage("ReSharper", "LocalVariableHidesMember")] | |||
| protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int element_size) | |||
| { | |||
| if (dt == TF_DataType.TF_STRING && data is byte[]) | |||
| if (dt == TF_DataType.TF_STRING && data is byte[] buffer) | |||
| { | |||
| var buffer = (byte[])data; | |||
| var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); | |||
| var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); | |||
| @@ -601,7 +623,7 @@ namespace Tensorflow | |||
| Marshal.WriteInt64(tensor, 0); | |||
| var status = new Status(); | |||
| fixed (byte* src = &buffer[0]) | |||
| fixed (byte* src = buffer) | |||
| c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); | |||
| status.Check(true); | |||
| @@ -644,8 +666,9 @@ namespace Tensorflow | |||
| { | |||
| if (args.deallocator_called) | |||
| return; | |||
| // NumSharp will dispose | |||
| // Marshal.FreeHGlobal(dataPtr); | |||
| Marshal.FreeHGlobal(dataPtr); | |||
| args.deallocator_called = true; | |||
| } | |||
| @@ -1,4 +1,5 @@ | |||
| using System; | |||
| using System.Runtime.CompilerServices; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -6,86 +7,142 @@ namespace Tensorflow | |||
| { | |||
| public static explicit operator bool(Tensor tensor) | |||
| { | |||
| EnsureScalar(tensor); | |||
| return tensor.Data<bool>()[0]; | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_BOOL); | |||
| return *(bool*) tensor.buffer; | |||
| } | |||
| } | |||
| public static explicit operator sbyte(Tensor tensor) | |||
| { | |||
| EnsureScalar(tensor); | |||
| return tensor.Data<sbyte>()[0]; | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_INT8); | |||
| return *(sbyte*) tensor.buffer; | |||
| } | |||
| } | |||
| public static explicit operator byte(Tensor tensor) | |||
| { | |||
| EnsureScalar(tensor); | |||
| return tensor.Data<byte>()[0]; | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_UINT8); | |||
| return *(byte*) tensor.buffer; | |||
| } | |||
| } | |||
| public static explicit operator ushort(Tensor tensor) | |||
| { | |||
| EnsureScalar(tensor); | |||
| return tensor.Data<ushort>()[0]; | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_UINT16); | |||
| return *(ushort*) tensor.buffer; | |||
| } | |||
| } | |||
| public static explicit operator short(Tensor tensor) | |||
| { | |||
| EnsureScalar(tensor); | |||
| return tensor.Data<short>()[0]; | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_INT16); | |||
| return *(short*) tensor.buffer; | |||
| } | |||
| } | |||
| public static explicit operator int(Tensor tensor) | |||
| { | |||
| EnsureScalar(tensor); | |||
| return tensor.Data<int>()[0]; | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_INT32); | |||
| return *(int*) tensor.buffer; | |||
| } | |||
| } | |||
| public static explicit operator uint(Tensor tensor) | |||
| { | |||
| EnsureScalar(tensor); | |||
| return tensor.Data<uint>()[0]; | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_UINT32); | |||
| return *(uint*) tensor.buffer; | |||
| } | |||
| } | |||
| public static explicit operator long(Tensor tensor) | |||
| { | |||
| EnsureScalar(tensor); | |||
| return tensor.Data<long>()[0]; | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_INT64); | |||
| return *(long*) tensor.buffer; | |||
| } | |||
| } | |||
| public static explicit operator ulong(Tensor tensor) | |||
| { | |||
| EnsureScalar(tensor); | |||
| return tensor.Data<ulong>()[0]; | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_UINT64); | |||
| return *(ulong*) tensor.buffer; | |||
| } | |||
| } | |||
| public static explicit operator float(Tensor tensor) | |||
| { | |||
| EnsureScalar(tensor); | |||
| return tensor.Data<float>()[0]; | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_FLOAT); | |||
| return *(float*) tensor.buffer; | |||
| } | |||
| } | |||
| public static explicit operator double(Tensor tensor) | |||
| { | |||
| EnsureScalar(tensor); | |||
| return tensor.Data<double>()[0]; | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_DOUBLE); | |||
| return *(double*) tensor.buffer; | |||
| } | |||
| } | |||
| public static explicit operator string(Tensor tensor) | |||
| { | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_STRING); | |||
| return new string((char*) tensor.buffer, 0, (int) tensor.size); | |||
| } | |||
| } | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| private static void EnsureDType(Tensor tensor, TF_DataType @is) | |||
| { | |||
| if (tensor.dtype != @is) | |||
| throw new InvalidCastException($"Unable to cast scalar tensor {tensor.dtype} to {@is}"); | |||
| } | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| private static void EnsureScalar(Tensor tensor) | |||
| { | |||
| if (tensor == null) | |||
| { | |||
| throw new ArgumentNullException(nameof(tensor)); | |||
| } | |||
| if (tensor.TensorShape.ndim != 0) | |||
| { | |||
| throw new ArgumentException("Tensor must have 0 dimensions in order to convert to scalar"); | |||
| } | |||
| if (tensor.TensorShape.size != 1) | |||
| { | |||
| throw new ArgumentException("Tensor must have size 1 in order to convert to scalar"); | |||
| } | |||
| } | |||
| } | |||
| @@ -69,11 +69,12 @@ namespace Tensorflow | |||
| TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32, | |||
| TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 | |||
| }; | |||
| public static Tensor operator /(double x, Tensor y) => BinaryOpWrapper("truediv", x, y); | |||
| public static Tensor operator /(float x, Tensor y) => BinaryOpWrapper("truediv", x, y); | |||
| public static Tensor operator /(int x, Tensor y) => BinaryOpWrapper("floordiv", x, y); | |||
| public static Tensor operator /(Tensor x, Tensor y) => | |||
| _intTfDataTypes.Contains(x._dtype) | |||
| _intTfDataTypes.Contains(x.dtype) | |||
| ? BinaryOpWrapper("floordiv", x, y) | |||
| : BinaryOpWrapper("truediv", x, y); | |||
| public static Tensor operator /(Tensor x, int y) => BinaryOpWrapper("floordiv", x, y); | |||
| @@ -122,8 +123,7 @@ namespace Tensorflow | |||
| if (y is Tensor tr) | |||
| dtype = tr.dtype.as_base_dtype(); | |||
| var namescope = ops.name_scope(null, name, new { x, y }); | |||
| return tf_with(namescope, scope => | |||
| return tf_with(ops.name_scope(null, name, new { x, y }), scope => | |||
| { | |||
| Tensor result = null; | |||
| var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); | |||
| @@ -155,7 +155,6 @@ namespace Tensorflow | |||
| return result; | |||
| }); | |||
| } | |||
| } | |||
| } | |||
| @@ -17,9 +17,16 @@ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics.CodeAnalysis; | |||
| using System.Globalization; | |||
| using System.Linq; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Runtime.InteropServices; | |||
| using System.Text; | |||
| using System.Threading.Tasks; | |||
| using NumSharp.Backends; | |||
| using NumSharp.Backends.Unmanaged; | |||
| using NumSharp.Utilities; | |||
| using Tensorflow.Framework; | |||
| using static Tensorflow.Binding; | |||
| @@ -29,42 +36,68 @@ namespace Tensorflow | |||
| /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. | |||
| /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. | |||
| /// </summary> | |||
| [SuppressMessage("ReSharper", "ConvertToAutoProperty")] | |||
| public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike | |||
| { | |||
| private int _id; | |||
| private Operation _op; | |||
| private readonly int _id; | |||
| private readonly Operation _op; | |||
| private readonly int _value_index; | |||
| private TF_Output? _tf_output; | |||
| private readonly TF_DataType _override_dtype; | |||
| public int Id => _id; | |||
| /// <summary> | |||
| /// The Graph that contains this tensor. | |||
| /// </summary> | |||
| public Graph graph => op?.graph; | |||
| /// <summary> | |||
| /// The Operation that produces this tensor as an output. | |||
| /// </summary> | |||
| public Operation op => _op; | |||
| public Tensor[] outputs => op.outputs; | |||
| /// <summary> | |||
| /// The string name of this tensor. | |||
| /// The string name of this tensor. | |||
| /// </summary> | |||
| public string name => $"{(op == null ? "<unnamed Operation>" : $"{op.name}:{_value_index}")}"; | |||
| private int _value_index; | |||
| /// <summary> | |||
| /// The index of this tensor in the outputs of its Operation. | |||
| /// </summary> | |||
| public int value_index => _value_index; | |||
| private TF_DataType _dtype = TF_DataType.DtInvalid; | |||
| public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle); | |||
| /// <summary> | |||
| /// The DType of elements in this tensor. | |||
| /// </summary> | |||
| public TF_DataType dtype => _handle == IntPtr.Zero ? _override_dtype : c_api.TF_TensorType(_handle); | |||
| public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); | |||
| public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); | |||
| public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; | |||
| public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); | |||
| public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | |||
| public int NDims => rank; | |||
| private TF_Output? _tf_output; | |||
| /// <summary> | |||
| /// The name of the device on which this tensor will be produced, or null. | |||
| /// </summary> | |||
| public string Device => op.Device; | |||
| public int[] dims => shape; | |||
| /// <summary> | |||
| /// used for keep other pointer when do implicit operating | |||
| /// Used for keep other pointer when do implicit operating | |||
| /// </summary> | |||
| public object Tag { get; set; } | |||
| /// <summary> | |||
| /// Returns the shape of a tensor. | |||
| /// </summary> | |||
| /// <remarks>https://www.tensorflow.org/api_docs/python/tf/shape</remarks> | |||
| public int[] shape | |||
| { | |||
| get | |||
| @@ -76,14 +109,13 @@ namespace Tensorflow | |||
| var status = new Status(); | |||
| c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); | |||
| status.Check(); | |||
| } | |||
| else | |||
| } else | |||
| { | |||
| for (int i = 0; i < rank; i++) | |||
| dims[i] = c_api.TF_Dim(_handle, i); | |||
| } | |||
| return dims.Select(x => Convert.ToInt32(x)).ToArray(); | |||
| return dims.Select(x => ((IConvertible) x).ToInt32(CultureInfo.InvariantCulture)).ToArray(); | |||
| } | |||
| set | |||
| @@ -93,38 +125,52 @@ namespace Tensorflow | |||
| if (value == null) | |||
| c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); | |||
| else | |||
| c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(x => Convert.ToInt64(x)).ToArray(), value.Length, status); | |||
| c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); | |||
| } | |||
| } | |||
| public int[] _shape_tuple() | |||
| { | |||
| if (shape == null) return null; | |||
| return shape.Select(x => (int)x).ToArray(); | |||
| return (int[]) shape.Clone(); | |||
| } | |||
| public TensorShape TensorShape => tensor_util.to_shape(shape); | |||
| public void SetShape(TensorShape shape) | |||
| /// <summary> | |||
| /// Updates the shape of this tensor. | |||
| /// </summary> | |||
| public void set_shape(TensorShape shape) | |||
| { | |||
| this.shape = shape.dims; | |||
| this.shape = (int[]) shape.dims.Clone(); | |||
| } | |||
| /// <summary> | |||
| /// Updates the shape of this tensor. | |||
| /// </summary> | |||
| [Obsolete("Please use set_shape(TensorShape shape) instead.", false)] | |||
| public void SetShape(TensorShape shape) | |||
| { | |||
| this.shape = (int[]) shape.dims.Clone(); | |||
| } | |||
| /// <summary> | |||
| /// Updates the shape of this tensor. | |||
| /// </summary> | |||
| public void set_shape(Tensor shape) | |||
| { | |||
| // ReSharper disable once MergeConditionalExpression | |||
| this.shape = shape is null ? null : shape.shape; | |||
| } | |||
| public int[] dims => shape; | |||
| /// <summary> | |||
| /// number of dimensions | |||
| /// 0 Scalar (magnitude only) | |||
| /// 1 Vector (magnitude and direction) | |||
| /// 2 Matrix (table of numbers) | |||
| /// 3 3-Tensor (cube of numbers) | |||
| /// number of dimensions <br></br> | |||
| /// 0 Scalar (magnitude only) <br></br> | |||
| /// 1 Vector (magnitude and direction) <br></br> | |||
| /// 2 Matrix (table of numbers) <br></br> | |||
| /// 3 3-Tensor (cube of numbers) <br></br> | |||
| /// n n-Tensor (you get the idea) | |||
| /// </summary> | |||
| /// <remarks>https://www.tensorflow.org/api_docs/python/tf/rank</remarks> | |||
| public int rank | |||
| { | |||
| get | |||
| @@ -137,17 +183,15 @@ namespace Tensorflow | |||
| status.Check(); | |||
| return ndim; | |||
| } | |||
| else | |||
| { | |||
| return c_api.TF_NumDims(_handle); | |||
| } | |||
| return c_api.TF_NumDims(_handle); | |||
| } | |||
| } | |||
| public int NDims => rank; | |||
| public string Device => op.Device; | |||
| /// <summary> | |||
| /// Returns a list of Operations that consume this tensor. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public Operation[] consumers() | |||
| { | |||
| var output = _as_tf_output(); | |||
| @@ -157,38 +201,181 @@ namespace Tensorflow | |||
| public TF_Output _as_tf_output() | |||
| { | |||
| if(!_tf_output.HasValue) | |||
| if (!_tf_output.HasValue) | |||
| _tf_output = new TF_Output(op, value_index); | |||
| return _tf_output.Value; | |||
| } | |||
| public T[] Data<T>() | |||
| [Obsolete("Please use ToArray<T>() instead.", false)] | |||
| public T[] Data<T>() where T : unmanaged | |||
| { | |||
| // Column major order | |||
| // https://en.wikipedia.org/wiki/File:Row_and_column_major_order.svg | |||
| // matrix:[[1, 2, 3], [4, 5, 6]] | |||
| // index: 0 2 4 1 3 5 | |||
| // result: 1 4 2 5 3 6 | |||
| var data = new T[size]; | |||
| for (ulong i = 0; i < size; i++) | |||
| return ToArray<T>(); | |||
| } | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| /// <returns></returns> | |||
| /// <exception cref="ArgumentException">When <typeparam name="T"> is string </typeparam></exception> | |||
| public T[] ToArray<T>() where T : unmanaged | |||
| { | |||
| //Are the types matching? | |||
| if (typeof(T).as_dtype() == dtype) | |||
| { | |||
| data[i] = Marshal.PtrToStructure<T>(buffer + (int)(i * itemsize)); | |||
| } | |||
| if (NDims == 0 && size == 1) //is it a scalar? | |||
| { | |||
| unsafe | |||
| { | |||
| return new T[] {*(T*) buffer}; | |||
| } | |||
| } | |||
| //types match, no need to perform cast | |||
| var ret = new T[size]; | |||
| unsafe | |||
| { | |||
| var len = (long) size; | |||
| fixed (T* dst = ret) | |||
| { | |||
| //T can only be unmanaged, I believe it is safe to say that MemoryCopy is valid for all cases this method can be called. | |||
| var src = (T*) buffer; | |||
| len *= ((long) itemsize); | |||
| System.Buffer.MemoryCopy(src, dst, len, len); | |||
| } | |||
| } | |||
| return ret; | |||
| } else | |||
| { | |||
| //types do not match, need to perform cast | |||
| if (NDims == 0 && size == 1) //is it a scalar? | |||
| { | |||
| unsafe | |||
| { | |||
| #if _REGEN | |||
| #region Compute | |||
| switch (dtype.as_numpy_dtype().GetTypeCode()) | |||
| { | |||
| %foreach supported_dtypes,supported_dtypes_lowercase% | |||
| case NPTypeCode.#1: return new T[] {Converts.ChangeType<T>(*(#2*) buffer, NPTypeCode.#1)}; | |||
| % | |||
| case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this, NPTypeCode.String)}; | |||
| default: | |||
| throw new NotSupportedException(); | |||
| } | |||
| #endregion | |||
| #else | |||
| #region Compute | |||
| switch (dtype.as_numpy_dtype()?.GetTypeCode()) | |||
| { | |||
| case NPTypeCode.Boolean: return new T[] {Converts.ChangeType<T>(*(bool*) buffer, NPTypeCode.Boolean)}; | |||
| case NPTypeCode.Byte: return new T[] {Converts.ChangeType<T>(*(byte*) buffer, NPTypeCode.Byte)}; | |||
| case NPTypeCode.Int16: return new T[] {Converts.ChangeType<T>(*(short*) buffer, NPTypeCode.Int16)}; | |||
| case NPTypeCode.UInt16: return new T[] {Converts.ChangeType<T>(*(ushort*) buffer, NPTypeCode.UInt16)}; | |||
| case NPTypeCode.Int32: return new T[] {Converts.ChangeType<T>(*(int*) buffer, NPTypeCode.Int32)}; | |||
| case NPTypeCode.UInt32: return new T[] {Converts.ChangeType<T>(*(uint*) buffer, NPTypeCode.UInt32)}; | |||
| case NPTypeCode.Int64: return new T[] {Converts.ChangeType<T>(*(long*) buffer, NPTypeCode.Int64)}; | |||
| case NPTypeCode.UInt64: return new T[] {Converts.ChangeType<T>(*(ulong*) buffer, NPTypeCode.UInt64)}; | |||
| case NPTypeCode.Char: return new T[] {Converts.ChangeType<T>(*(char*) buffer, NPTypeCode.Char)}; | |||
| case NPTypeCode.Double: return new T[] {Converts.ChangeType<T>(*(double*) buffer, NPTypeCode.Double)}; | |||
| case NPTypeCode.Single: return new T[] {Converts.ChangeType<T>(*(float*) buffer, NPTypeCode.Single)}; | |||
| case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this, NPTypeCode.String)}; | |||
| default: | |||
| throw new NotSupportedException(); | |||
| } | |||
| #endregion | |||
| #endif | |||
| } | |||
| } | |||
| return data; | |||
| var ret = new T[size]; | |||
| unsafe | |||
| { | |||
| var len = (long) size; | |||
| fixed (T* dstRet = ret) | |||
| { | |||
| T* dst = dstRet; //local stack copy | |||
| #if _REGEN | |||
| #region Compute | |||
| switch (dtype.as_numpy_dtype().GetTypeCode()) | |||
| { | |||
| %foreach supported_dtypes,supported_dtypes_lowercase% | |||
| case NPTypeCode.#1: new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| % | |||
| default: | |||
| throw new NotSupportedException(); | |||
| } | |||
| #endregion | |||
| #else | |||
| #region Compute | |||
| switch (dtype.as_numpy_dtype().GetTypeCode()) | |||
| { | |||
| case NPTypeCode.Boolean: new UnmanagedMemoryBlock<bool>((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Byte: new UnmanagedMemoryBlock<byte>((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Int16: new UnmanagedMemoryBlock<short>((short*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.UInt16: new UnmanagedMemoryBlock<ushort>((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Int32: new UnmanagedMemoryBlock<int>((int*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.UInt32: new UnmanagedMemoryBlock<uint>((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Int64: new UnmanagedMemoryBlock<long>((long*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.UInt64: new UnmanagedMemoryBlock<ulong>((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Char: new UnmanagedMemoryBlock<char>((char*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Double: new UnmanagedMemoryBlock<double>((double*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Single: new UnmanagedMemoryBlock<float>((float*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.String: throw new NotSupportedException("Unable to convert from string to other dtypes"); //TODO! this should call Converts.To<T> | |||
| default: | |||
| throw new NotSupportedException(); | |||
| } | |||
| #endregion | |||
| #endif | |||
| } | |||
| } | |||
| return ret; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Copies the memory of current buffer onto newly allocated array. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| [Obsolete("Please use set_shape(TensorShape shape) instead.", false)] | |||
| public byte[] Data() | |||
| { | |||
| var data = new byte[bytesize]; | |||
| Marshal.Copy(buffer, data, 0, (int)bytesize); | |||
| return data; | |||
| return BufferToArray(); | |||
| } | |||
| /// <summary> | |||
| /// Copies the memory of current buffer onto newly allocated array. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public byte[] BufferToArray() | |||
| { | |||
| unsafe | |||
| { | |||
| // ReSharper disable once LocalVariableHidesMember | |||
| var bytesize = (long) this.bytesize; | |||
| var data = new byte[bytesize]; | |||
| fixed (byte* dst = data) | |||
| System.Buffer.MemoryCopy(buffer.ToPointer(), dst, bytesize, bytesize); | |||
| return data; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Extracts string array from current Tensor. | |||
| /// </summary> | |||
| /// <exception cref="InvalidOperationException">When <see cref="dtype"/> != TF_DataType.TF_STRING</exception> | |||
| public unsafe string[] StringData() | |||
| { | |||
| if (dtype != TF_DataType.TF_STRING) | |||
| throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})"); | |||
| // | |||
| // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. | |||
| // [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes] | |||
| @@ -199,19 +386,19 @@ namespace Tensorflow | |||
| var buffer = new byte[size][]; | |||
| var src = c_api.TF_TensorData(_handle); | |||
| var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize); | |||
| src += (int)(size * 8); | |||
| var srcLen = (IntPtr) (src.ToInt64() + (long) bytesize); | |||
| src += (int) (size * 8); | |||
| for (int i = 0; i < buffer.Length; i++) | |||
| { | |||
| using (var status = new Status()) | |||
| { | |||
| IntPtr dst = IntPtr.Zero; | |||
| UIntPtr dstLen = UIntPtr.Zero; | |||
| var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status); | |||
| var read = c_api.TF_StringDecode((byte*) src, (UIntPtr) (srcLen.ToInt64() - src.ToInt64()), (byte**) &dst, &dstLen, status); | |||
| status.Check(true); | |||
| buffer[i] = new byte[(int)dstLen]; | |||
| buffer[i] = new byte[(int) dstLen]; | |||
| Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | |||
| src += (int)read; | |||
| src += (int) read; | |||
| } | |||
| } | |||
| @@ -229,51 +416,29 @@ namespace Tensorflow | |||
| } | |||
| /// <summary> | |||
| /// Evaluates this tensor in a `Session`. | |||
| /// Evaluates this tensor in a `Session`. | |||
| /// </summary> | |||
| /// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param> | |||
| /// <param name="session">The `Session` to be used to evaluate this tensor.</param> | |||
| /// <returns></returns> | |||
| /// <returns>A <see cref="NumSharp"/> array corresponding to the value of this tensor.</returns> | |||
| public NDArray eval(params FeedItem[] feed_dict) | |||
| { | |||
| return ops._eval_using_default_session(this, feed_dict, graph); | |||
| } | |||
| public NDArray eval(Session session, FeedItem[] feed_dict = null) | |||
| /// <summary> | |||
| /// Evaluates this tensor in a `Session`. | |||
| /// </summary> | |||
| /// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param> | |||
| /// <param name="session">The `Session` to be used to evaluate this tensor.</param> | |||
| /// <returns>A <see cref="NumSharp"/> array corresponding to the value of this tensor.</returns> | |||
| public NDArray eval(Session session, params FeedItem[] feed_dict) | |||
| { | |||
| return ops._eval_using_default_session(this, feed_dict, graph, session); | |||
| } | |||
| public TF_DataType ToTFDataType(Type type) | |||
| { | |||
| switch (type.Name) | |||
| { | |||
| case "Char": | |||
| return TF_DataType.TF_UINT8; | |||
| case "Int16": | |||
| return TF_DataType.TF_INT16; | |||
| case "Int32": | |||
| return TF_DataType.TF_INT32; | |||
| case "Int64": | |||
| return TF_DataType.TF_INT64; | |||
| case "Single": | |||
| return TF_DataType.TF_FLOAT; | |||
| case "Double": | |||
| return TF_DataType.TF_DOUBLE; | |||
| case "Byte": | |||
| return TF_DataType.TF_UINT8; | |||
| case "String": | |||
| return TF_DataType.TF_STRING; | |||
| case "Boolean": | |||
| return TF_DataType.TF_BOOL; | |||
| default: | |||
| throw new NotImplementedException("ToTFDataType error"); | |||
| } | |||
| } | |||
| public Tensor slice(Slice slice) | |||
| { | |||
| var slice_spec = new int[] { slice.Start.Value }; | |||
| var slice_spec = new int[] {slice.Start.Value}; | |||
| var begin = new List<int>(); | |||
| var end = new List<int>(); | |||
| var strides = new List<int>(); | |||
| @@ -289,26 +454,26 @@ namespace Tensorflow | |||
| if (slice.Stop.HasValue) | |||
| { | |||
| end.Add(slice.Stop.Value); | |||
| } | |||
| else | |||
| } else | |||
| { | |||
| end.Add(0); | |||
| end_mask |= (1 << index); | |||
| } | |||
| strides.Add(slice.Step); | |||
| index += 1; | |||
| } | |||
| return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => | |||
| return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope => | |||
| { | |||
| string name = scope; | |||
| if (begin != null) | |||
| { | |||
| var (packed_begin, packed_end, packed_strides) = | |||
| (array_ops.stack(begin.ToArray()), | |||
| array_ops.stack(end.ToArray()), | |||
| array_ops.stack(strides.ToArray())); | |||
| array_ops.stack(end.ToArray()), | |||
| array_ops.stack(strides.ToArray())); | |||
| return gen_array_ops.strided_slice( | |||
| this, | |||
| @@ -320,7 +485,6 @@ namespace Tensorflow | |||
| shrink_axis_mask: shrink_axis_mask, | |||
| new_axis_mask: new_axis_mask, | |||
| ellipsis_mask: ellipsis_mask, | |||
| name: name); | |||
| } | |||
| @@ -330,7 +494,7 @@ namespace Tensorflow | |||
| public Tensor slice(int start) | |||
| { | |||
| var slice_spec = new int[] { start }; | |||
| var slice_spec = new int[] {start}; | |||
| var begin = new List<int>(); | |||
| var end = new List<int>(); | |||
| var strides = new List<int>(); | |||
| @@ -349,15 +513,15 @@ namespace Tensorflow | |||
| index += 1; | |||
| } | |||
| return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => | |||
| return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope => | |||
| { | |||
| string name = scope; | |||
| if (begin != null) | |||
| { | |||
| var (packed_begin, packed_end, packed_strides) = | |||
| (array_ops.stack(begin.ToArray()), | |||
| array_ops.stack(end.ToArray()), | |||
| array_ops.stack(strides.ToArray())); | |||
| array_ops.stack(end.ToArray()), | |||
| array_ops.stack(strides.ToArray())); | |||
| return gen_array_ops.strided_slice( | |||
| this, | |||
| @@ -369,7 +533,6 @@ namespace Tensorflow | |||
| shrink_axis_mask: shrink_axis_mask, | |||
| new_axis_mask: new_axis_mask, | |||
| ellipsis_mask: ellipsis_mask, | |||
| name: name); | |||
| } | |||
| @@ -392,29 +555,13 @@ namespace Tensorflow | |||
| return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; | |||
| } | |||
| protected override void DisposeManagedState() | |||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||
| { | |||
| c_api.TF_DeleteTensor(handle); | |||
| } | |||
| protected override void DisposeUnManagedState(IntPtr handle) | |||
| { | |||
| if(handle != IntPtr.Zero) | |||
| { | |||
| c_api.TF_DeleteTensor(handle); | |||
| } | |||
| } | |||
| public bool IsDisposed | |||
| { | |||
| get | |||
| { | |||
| lock (this) | |||
| { | |||
| return _handle == IntPtr.Zero; | |||
| } | |||
| } | |||
| } | |||
| public bool IsDisposed => _disposed; | |||
| public int tensor_int_val { get; set; } | |||
| } | |||
| } | |||
| } | |||
| @@ -1,35 +1,84 @@ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Diagnostics.CodeAnalysis; | |||
| using System.Linq; | |||
| using System.Runtime.CompilerServices; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// Represents the shape of a `Tensor`. | |||
| /// Represents the shape of a `Tensor`. | |||
| /// </summary> | |||
| /// <remarks>https://www.tensorflow.org/api_docs/python/tf/TensorShape</remarks> | |||
| public class TensorShape | |||
| { | |||
| private Shape shape; | |||
| private readonly Shape shape; | |||
| /// <summary> | |||
| /// Returns a list of Dimensions, or None if the shape is unspecified. | |||
| /// </summary> | |||
| public int[] dims => shape.Dimensions; | |||
| /// <summary> | |||
| /// Returns the rank of this shape. | |||
| /// </summary> | |||
| public int ndim => shape.NDim; | |||
| /// <summary> | |||
| /// Returns the rank of this shape. | |||
| /// </summary> | |||
| public int rank => shape.NDim; | |||
| /// <summary> | |||
| /// Returns the size this shape represents. | |||
| /// </summary> | |||
| public int size => shape.Size; | |||
| public TensorShape(TensorShapeProto proto) | |||
| { | |||
| if (proto.UnknownRank) return; | |||
| switch (proto.Dim.Count) | |||
| { | |||
| case 0: shape = new Shape(new int[0]); break; | |||
| case 1: shape = Shape.Vector((int) proto.Dim[0].Size); break; | |||
| case 2: shape = Shape.Matrix((int) proto.Dim[0].Size, (int) proto.Dim[1].Size); break; | |||
| default: | |||
| var protodims = proto.Dim; | |||
| var len = protodims.Count; | |||
| var dims = new int[len]; | |||
| for (int i = 0; i < len; i++) | |||
| dims[i] = (int) protodims[i].Size; | |||
| shape.reshape(proto.Dim.Select(x => (int)x.Size).ToArray()); | |||
| shape = new Shape(dims); break; | |||
| } | |||
| } | |||
| public TensorShape(params int[] dims) | |||
| { | |||
| shape = new Shape(dims); | |||
| switch (dims.Length) | |||
| { | |||
| case 0: shape = new Shape(new int[0]); break; | |||
| case 1: shape = Shape.Vector((int) dims[0]); break; | |||
| case 2: shape = Shape.Matrix(dims[0], dims[1]); break; | |||
| default: shape = new Shape(dims); break; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <param name="slice"></param> | |||
| /// <returns></returns> | |||
| /// <exception cref="ArgumentException">When <see cref="Slice"/> is not an Index.</exception> | |||
| [SuppressMessage("ReSharper", "PossibleInvalidOperationException")] | |||
| public TensorShape this[Slice slice] | |||
| { | |||
| get | |||
| { | |||
| if (slice.Start.HasValue == false || slice.Length.HasValue == false) | |||
| throw new ArgumentException("Slice must has Start and Length."); | |||
| return new TensorShape(dims.Skip(slice.Start.Value) | |||
| .Take(slice.Length.Value) | |||
| .ToArray()); | |||
| @@ -37,7 +86,7 @@ namespace Tensorflow | |||
| } | |||
| /// <summary> | |||
| /// Returns True iff `self` is fully defined in every dimension. | |||
| /// Returns True iff `self` is fully defined in every dimension. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public bool is_fully_defined() | |||
| @@ -50,6 +99,7 @@ namespace Tensorflow | |||
| throw new NotImplementedException("TensorShape is_compatible_with"); | |||
| } | |||
| [SuppressMessage("ReSharper", "ParameterHidesMember")] | |||
| public TensorShape with_rank_at_least(int rank) | |||
| { | |||
| if (rank != ndim) | |||
| @@ -59,35 +109,68 @@ namespace Tensorflow | |||
| } | |||
| /// <summary> | |||
| /// Returns the concatenation of the dimension in `self` and `other`. | |||
| /// Returns the concatenation of the dimension in `self` and `other`. | |||
| /// </summary> | |||
| /// <param name="other"></param> | |||
| /// <returns></returns> | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| public TensorShape concatenate(int[] other) | |||
| { | |||
| return concatenate(new TensorShape(other)); | |||
| } | |||
| /// <summary> | |||
| /// Returns the concatenation of the dimension in `self` and `other`. | |||
| /// </summary> | |||
| /// <param name="other"></param> | |||
| /// <returns></returns> | |||
| public TensorShape concatenate(int[] other_) | |||
| public TensorShape concatenate(TensorShape other) | |||
| { | |||
| var other = new TensorShape(other_); | |||
| var otherShape = other; | |||
| if (ndim < 0 || other.ndim < 0) | |||
| if (ndim < 0 || otherShape.ndim < 0) | |||
| return new TensorShape(); | |||
| else | |||
| { | |||
| var concatenate_dims = new int[ndim + other.ndim]; | |||
| var concatenate_dims = new int[ndim + otherShape.ndim]; | |||
| for (int i = 0; i < ndim; i++) | |||
| concatenate_dims[i] = dims[i]; | |||
| for (int i = 0; i < other.ndim; i++) | |||
| concatenate_dims[ndim + i] = other.dims[i]; | |||
| for (int i = 0; i < otherShape.ndim; i++) | |||
| concatenate_dims[ndim + i] = otherShape.dims[i]; | |||
| return new TensorShape(concatenate_dims); | |||
| } | |||
| } | |||
| public static implicit operator TensorShape(Shape shape) => new TensorShape(shape.Dimensions); | |||
| public static implicit operator Shape(TensorShape shape) => new Shape(shape.dims); | |||
| public override string ToString() | |||
| { | |||
| return shape.ToString(); | |||
| } | |||
| public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone()); | |||
| public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone()); | |||
| public static implicit operator int[](TensorShape shape) => (int[])shape.dims.Clone(); //we clone to avoid any changes | |||
| public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); | |||
| public static implicit operator int[](TensorShape shape) => shape.dims; | |||
| public static explicit operator int(TensorShape shape) => shape.size; | |||
| public static explicit operator TensorShape(int dim) => new TensorShape(dim); | |||
| public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0); | |||
| public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); | |||
| public static explicit operator (int, int, int)(TensorShape shape) => shape.dims.Length == 3 ? (shape.dims[0], shape.dims[1], shape.dims[2]) : (0, 0, 0); | |||
| public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); | |||
| public static explicit operator (int, int, int, int)(TensorShape shape) => shape.dims.Length == 4 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3]) : (0, 0, 0, 0); | |||
| public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); | |||
| public static explicit operator (int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 5 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4]) : (0, 0, 0, 0, 0); | |||
| public static implicit operator TensorShape((int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5); | |||
| public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0); | |||
| public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6); | |||
| } | |||
| } | |||
| @@ -15,6 +15,8 @@ | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Numerics; | |||
| using NumSharp.Backends; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -23,35 +25,100 @@ namespace Tensorflow | |||
| public static TF_DataType int8 = TF_DataType.TF_INT8; | |||
| public static TF_DataType int32 = TF_DataType.TF_INT32; | |||
| public static TF_DataType int64 = TF_DataType.TF_INT64; | |||
| public static TF_DataType uint8 = TF_DataType.TF_UINT8; | |||
| public static TF_DataType uint32 = TF_DataType.TF_UINT32; | |||
| public static TF_DataType uint64 = TF_DataType.TF_UINT64; | |||
| public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? | |||
| public static TF_DataType float16 = TF_DataType.TF_HALF; | |||
| public static TF_DataType float64 = TF_DataType.TF_DOUBLE; | |||
| public static Type as_numpy_datatype(this TF_DataType type) | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <param name="type"></param> | |||
| /// <returns><see cref="System.Type"/> equivalent to <paramref name="type"/>, if none exists, returns null.</returns> | |||
| public static Type as_numpy_dtype(this TF_DataType type) | |||
| { | |||
| switch (type) | |||
| { | |||
| case TF_DataType.TF_BOOL: | |||
| return typeof(bool); | |||
| case TF_DataType.TF_UINT8: | |||
| return typeof(byte); | |||
| case TF_DataType.TF_INT64: | |||
| return typeof(long); | |||
| case TF_DataType.TF_UINT64: | |||
| return typeof(ulong); | |||
| case TF_DataType.TF_INT32: | |||
| return typeof(int); | |||
| case TF_DataType.TF_UINT32: | |||
| return typeof(uint); | |||
| case TF_DataType.TF_INT16: | |||
| return typeof(short); | |||
| case TF_DataType.TF_UINT16: | |||
| return typeof(ushort); | |||
| case TF_DataType.TF_FLOAT: | |||
| return typeof(float); | |||
| case TF_DataType.TF_DOUBLE: | |||
| return typeof(double); | |||
| case TF_DataType.TF_STRING: | |||
| return typeof(string); | |||
| case TF_DataType.TF_COMPLEX128: | |||
| case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX | |||
| return typeof(Complex); | |||
| default: | |||
| return null; | |||
| } | |||
| } | |||
| // "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex" | |||
| public static TF_DataType as_dtype(Type type, TF_DataType? dtype = null) | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <param name="type"></param> | |||
| /// <returns></returns> | |||
| /// <exception cref="ArgumentException">When <paramref name="type"/> has no equivalent <see cref="NPTypeCode"/></exception> | |||
| public static NPTypeCode as_numpy_typecode(this TF_DataType type) | |||
| { | |||
| switch (type) | |||
| { | |||
| case TF_DataType.TF_BOOL: | |||
| return NPTypeCode.Boolean; | |||
| case TF_DataType.TF_UINT8: | |||
| return NPTypeCode.Byte; | |||
| case TF_DataType.TF_INT64: | |||
| return NPTypeCode.Int64; | |||
| case TF_DataType.TF_INT32: | |||
| return NPTypeCode.Int32; | |||
| case TF_DataType.TF_INT16: | |||
| return NPTypeCode.Int16; | |||
| case TF_DataType.TF_UINT64: | |||
| return NPTypeCode.UInt64; | |||
| case TF_DataType.TF_UINT32: | |||
| return NPTypeCode.UInt32; | |||
| case TF_DataType.TF_UINT16: | |||
| return NPTypeCode.UInt16; | |||
| case TF_DataType.TF_FLOAT: | |||
| return NPTypeCode.Single; | |||
| case TF_DataType.TF_DOUBLE: | |||
| return NPTypeCode.Double; | |||
| case TF_DataType.TF_STRING: | |||
| return NPTypeCode.String; | |||
| case TF_DataType.TF_COMPLEX128: | |||
| case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX | |||
| return NPTypeCode.Complex; | |||
| default: | |||
| throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <param name="type"></param> | |||
| /// <param name="dtype"></param> | |||
| /// <returns></returns> | |||
| /// <exception cref="ArgumentException">When <paramref name="type"/> has no equivalent <see cref="TF_DataType"/></exception> | |||
| public static TF_DataType as_dtype(this Type type, TF_DataType? dtype = null) | |||
| { | |||
| switch (type.Name) | |||
| { | |||
| @@ -98,7 +165,7 @@ namespace Tensorflow | |||
| dtype = TF_DataType.TF_BOOL; | |||
| break; | |||
| default: | |||
| throw new Exception("as_dtype Not Implemented"); | |||
| throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); | |||
| } | |||
| return dtype.Value; | |||
| @@ -106,16 +173,7 @@ namespace Tensorflow | |||
| public static DataType as_datatype_enum(this TF_DataType type) | |||
| { | |||
| DataType dtype = DataType.DtInvalid; | |||
| switch (type) | |||
| { | |||
| default: | |||
| Enum.TryParse(((int)type).ToString(), out dtype); | |||
| break; | |||
| } | |||
| return dtype; | |||
| return Enum.TryParse(((int) type).ToString(), out DataType dtype) ? dtype : DataType.DtInvalid; | |||
| } | |||
| public static TF_DataType as_base_dtype(this TF_DataType type) | |||
| @@ -132,7 +190,7 @@ namespace Tensorflow | |||
| public static Type as_numpy_dtype(this DataType type) | |||
| { | |||
| return type.as_tf_dtype().as_numpy_datatype(); | |||
| return type.as_tf_dtype().as_numpy_dtype(); | |||
| } | |||
| public static DataType as_base_dtype(this DataType type) | |||
| @@ -144,16 +202,7 @@ namespace Tensorflow | |||
| public static TF_DataType as_tf_dtype(this DataType type) | |||
| { | |||
| TF_DataType dtype = TF_DataType.DtInvalid; | |||
| switch (type) | |||
| { | |||
| default: | |||
| Enum.TryParse(((int)type).ToString(), out dtype); | |||
| break; | |||
| } | |||
| return dtype; | |||
| return Enum.TryParse(((int) type).ToString(), out TF_DataType dtype) ? dtype : TF_DataType.DtInvalid; | |||
| } | |||
| public static TF_DataType as_ref(this TF_DataType type) | |||
| @@ -17,6 +17,7 @@ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Linq; | |||
| using NumSharp.Utilities; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -82,6 +83,12 @@ namespace Tensorflow | |||
| throw new NotImplementedException("MakeNdarray"); | |||
| } | |||
| private static readonly TF_DataType[] quantized_types = new TF_DataType[] | |||
| { | |||
| TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16, | |||
| TF_DataType.TF_QINT32 | |||
| }; | |||
| /// <summary> | |||
| /// Create a TensorProto. | |||
| /// </summary> | |||
| @@ -98,18 +105,9 @@ namespace Tensorflow | |||
| if (values is TensorProto tp) | |||
| return tp; | |||
| if (dtype != TF_DataType.DtInvalid) | |||
| ; | |||
| bool is_quantized = new TF_DataType[] | |||
| { | |||
| TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16, | |||
| TF_DataType.TF_QINT32 | |||
| }.Contains(dtype); | |||
| // We first convert value to a numpy array or scalar. | |||
| NDArray nparray = null; | |||
| var np_dt = dtype.as_numpy_datatype(); | |||
| var np_dt = dtype.as_numpy_dtype(); | |||
| if (values is NDArray nd) | |||
| { | |||
| @@ -188,37 +186,37 @@ namespace Tensorflow | |||
| if (values.GetType().IsArray) | |||
| nparray = np.array((int[])values, np_dt); | |||
| else | |||
| nparray = Convert.ToInt32(values); | |||
| nparray = Converts.ToInt32(values); | |||
| break; | |||
| case "Int64": | |||
| if (values.GetType().IsArray) | |||
| nparray = np.array((int[])values, np_dt); | |||
| else | |||
| nparray = Convert.ToInt64(values); | |||
| nparray = Converts.ToInt64(values); | |||
| break; | |||
| case "Single": | |||
| if (values.GetType().IsArray) | |||
| nparray = np.array((float[])values, np_dt); | |||
| else | |||
| nparray = Convert.ToSingle(values); | |||
| nparray = Converts.ToSingle(values); | |||
| break; | |||
| case "Double": | |||
| if (values.GetType().IsArray) | |||
| nparray = np.array((double[])values, np_dt); | |||
| else | |||
| nparray = Convert.ToDouble(values); | |||
| nparray = Converts.ToDouble(values); | |||
| break; | |||
| case "String": | |||
| if (values.GetType().IsArray) | |||
| nparray = np.array((string[])values, np_dt); | |||
| else | |||
| nparray = NDArray.FromString(Convert.ToString(values)); | |||
| nparray = NDArray.FromString(Converts.ToString(values)); | |||
| break; | |||
| case "Boolean": | |||
| if (values.GetType().IsArray) | |||
| nparray = np.array((bool[])values, np_dt); | |||
| else | |||
| nparray = Convert.ToBoolean(values); | |||
| nparray = Converts.ToBoolean(values); | |||
| break; | |||
| default: | |||
| throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented"); | |||
| @@ -226,13 +224,13 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| var numpy_dtype = dtypes.as_dtype(nparray.dtype, dtype: dtype); | |||
| var numpy_dtype = nparray.dtype.as_dtype(dtype: dtype); | |||
| if (numpy_dtype == TF_DataType.DtInvalid) | |||
| throw new TypeError($"Unrecognized data type: {nparray.dtype}"); | |||
| // If dtype was specified and is a quantized type, we convert | |||
| // numpy_dtype back into the quantized version. | |||
| if (is_quantized) | |||
| if (quantized_types.Contains(dtype)) | |||
| numpy_dtype = dtype; | |||
| bool is_same_size = false; | |||
| @@ -0,0 +1,52 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Train | |||
| { | |||
| public class ExponentialMovingAverage | |||
| { | |||
| float _decay; | |||
| int? _num_updates; | |||
| bool _zero_debias; | |||
| string _name; | |||
| public string name => _name; | |||
| List<VariableV1> _averages; | |||
| public ExponentialMovingAverage(float decay, int? num_updates = null, bool zero_debias = false, | |||
| string name = "ExponentialMovingAverage") | |||
| { | |||
| _decay = decay; | |||
| _num_updates = num_updates; | |||
| _zero_debias = zero_debias; | |||
| _name = name; | |||
| _averages = new List<VariableV1>(); | |||
| } | |||
| /// <summary> | |||
| /// Maintains moving averages of variables. | |||
| /// </summary> | |||
| /// <param name="var_list"></param> | |||
| /// <returns></returns> | |||
| public Operation apply(RefVariable[] var_list = null) | |||
| { | |||
| if (var_list == null) | |||
| var_list = variables.trainable_variables() as RefVariable[]; | |||
| foreach(var var in var_list) | |||
| { | |||
| if (!_averages.Contains(var)) | |||
| { | |||
| ops.init_scope(); | |||
| var slot = new SlotCreator(); | |||
| var.initialized_value(); | |||
| // var avg = slot.create_zeros_slot | |||
| } | |||
| } | |||
| throw new NotImplementedException(""); | |||
| } | |||
| } | |||
| } | |||
| @@ -198,7 +198,7 @@ namespace Tensorflow | |||
| if (!tf.context.executing_eagerly()) | |||
| { | |||
| var train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) as List<ITensorOrOperation>; | |||
| var train_op = ops.get_collection_ref(tf.GraphKeys.TRAIN_OP) as List<ITensorOrOperation>; | |||
| if (train_op != null && train_op.Contains(apply_updates)) | |||
| train_op.Add(apply_updates); | |||
| } | |||
| @@ -359,7 +359,7 @@ namespace Tensorflow | |||
| var tmp = variables.trainable_variables(); | |||
| var vars = ops.get_collection<RefVariable>(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); | |||
| var vars = ops.get_collection<RefVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); | |||
| switch (tmp) | |||
| { | |||
| case List<RefVariable> values: | |||
| @@ -370,7 +370,7 @@ namespace Tensorflow | |||
| break; | |||
| } | |||
| var_list = var_list.Concat(ops.get_collection<RefVariable>(ops.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); | |||
| var_list = var_list.Concat(ops.get_collection<RefVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); | |||
| var processors = var_list.Select(v => optimizer._get_processor(v)).ToList(); | |||
| var var_refs = processors.Select(x => x.target()).ToArray(); | |||
| @@ -0,0 +1,94 @@ | |||
| using System; | |||
| using System.IO; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Runtime.InteropServices; | |||
| using NumSharp.Backends.Unmanaged; | |||
| namespace Tensorflow.Util | |||
| { | |||
| public static class UnmanagedExtensions | |||
| { | |||
| //internally UnmanagedMemoryStream can't construct with null address. | |||
| private static readonly unsafe byte* _empty = (byte*) Marshal.AllocHGlobal(1); | |||
| /// <summary> | |||
| /// Creates a memory stream based on given <paramref name="block"/>. | |||
| /// </summary> | |||
| /// <param name="block">The block to stream. Can be default/null.</param> | |||
| /// <remarks>There is no need to dispose the returned <see cref="UnmanagedMemoryStream"/></remarks> | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| public static UnmanagedMemoryStream Stream(this UnmanagedMemoryBlock<byte> block) | |||
| { | |||
| unsafe | |||
| { | |||
| if (block.Address == null) | |||
| return new UnmanagedMemoryStream(_empty, 0); | |||
| return new UnmanagedMemoryStream(block.Address, block.BytesCount); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Creates a memory stream based on given <paramref name="block"/>. | |||
| /// </summary> | |||
| /// <param name="block">The block to stream. Can be default/null.</param> | |||
| /// <param name="offset">Offset from the start of the block.</param> | |||
| /// <remarks>There is no need to dispose the returned <see cref="UnmanagedMemoryStream"/></remarks> | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| public static UnmanagedMemoryStream Stream(this UnmanagedMemoryBlock<byte> block, long offset) | |||
| { | |||
| if (block.BytesCount - offset <= 0) | |||
| throw new ArgumentOutOfRangeException(nameof(offset)); | |||
| unsafe | |||
| { | |||
| if (block.Address == null) | |||
| return new UnmanagedMemoryStream(_empty, 0); | |||
| return new UnmanagedMemoryStream(block.Address + offset, block.BytesCount - offset); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Creates a memory stream based on given <paramref name="address"/>. | |||
| /// </summary> | |||
| /// <param name="address">The block to stream. Can be IntPtr.Zero.</param> | |||
| /// <param name="length">The length of the block in bytes.</param> | |||
| /// <remarks>There is no need to dispose the returned <see cref="UnmanagedMemoryStream"/></remarks> | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| public static UnmanagedMemoryStream Stream(this IntPtr address, long length) | |||
| { | |||
| if (length <= 0) | |||
| throw new ArgumentOutOfRangeException(nameof(length)); | |||
| unsafe | |||
| { | |||
| if (address == IntPtr.Zero) | |||
| return new UnmanagedMemoryStream(_empty, 0); | |||
| // ReSharper disable once AssignNullToNotNullAttribute | |||
| return new UnmanagedMemoryStream((byte*) address, length); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Creates a memory stream based on given <paramref name="address"/>. | |||
| /// </summary> | |||
| /// <param name="address">The block to stream. Can be IntPtr.Zero.</param> | |||
| /// <param name="offset">Offset from the start of the block.</param> | |||
| /// <param name="length">The length of the block in bytes.</param> | |||
| /// <remarks>There is no need to dispose the returned <see cref="UnmanagedMemoryStream"/></remarks> | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| public static UnmanagedMemoryStream Stream(this IntPtr address, long offset, long length) | |||
| { | |||
| if (length <= 0) | |||
| throw new ArgumentOutOfRangeException(nameof(length)); | |||
| unsafe | |||
| { | |||
| if (address == IntPtr.Zero) | |||
| return new UnmanagedMemoryStream(_empty, 0); | |||
| return new UnmanagedMemoryStream((byte*) address + offset, length); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -121,7 +121,7 @@ namespace Tensorflow | |||
| if(collections == null) | |||
| { | |||
| collections = new List<string> { ops.GraphKeys.GLOBAL_VARIABLES }; | |||
| collections = new List<string> { tf.GraphKeys.GLOBAL_VARIABLES }; | |||
| } | |||
| // Store the graph key so optimizers know how to only retrieve variables from | |||
| @@ -129,8 +129,8 @@ namespace Tensorflow | |||
| _graph_key = ops.get_default_graph().graph_key; | |||
| _trainable = trainable; | |||
| if (trainable && !collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES)) | |||
| collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES); | |||
| if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) | |||
| collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); | |||
| ops.init_scope(); | |||
| var values = init_from_fn ? new object[0] : new object[] { initial_value }; | |||
| @@ -158,7 +158,7 @@ namespace Tensorflow | |||
| // Or get the initial value from a Tensor or Python object. | |||
| else | |||
| { | |||
| _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); | |||
| _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype); | |||
| var shape = _initial_value.shape; | |||
| dtype = _initial_value.dtype; | |||
| @@ -308,5 +308,28 @@ namespace Tensorflow | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| /// <summary> | |||
| /// Returns the value of this variable, read in the current context. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| private ITensorOrOperation read_value() | |||
| { | |||
| return array_ops.identity(_variable, name: "read"); | |||
| } | |||
| public Tensor is_variable_initialized(RefVariable variable) | |||
| { | |||
| return state_ops.is_variable_initialized(variable); | |||
| } | |||
| public Tensor initialized_value() | |||
| { | |||
| ops.init_scope(); | |||
| throw new NotImplementedException(""); | |||
| /*return control_flow_ops.cond(is_variable_initialized(this), | |||
| read_value, | |||
| () => initial_value);*/ | |||
| } | |||
| } | |||
| } | |||
| @@ -14,6 +14,7 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Eager; | |||
| @@ -145,5 +146,10 @@ namespace Tensorflow | |||
| var _op = _op_def_lib._apply_op_helper("ScatterAdd", name: name, args: new { @ref, indices, updates, use_locking }); | |||
| return _op.outputs[0]; | |||
| } | |||
| public static Tensor is_variable_initialized(RefVariable @ref, string name = null) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| } | |||
| } | |||
| } | |||
| @@ -106,5 +106,13 @@ namespace Tensorflow | |||
| throw new NotImplementedException("scatter_add"); | |||
| } | |||
| public static Tensor is_variable_initialized(RefVariable @ref, string name = null) | |||
| { | |||
| if (@ref.dtype.is_ref_dtype()) | |||
| return gen_state_ops.is_variable_initialized(@ref: @ref, name: name); | |||
| throw new NotImplementedException(""); | |||
| //return @ref.is_initialized(name: name); | |||
| } | |||
| } | |||
| } | |||
| @@ -17,6 +17,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -28,7 +29,7 @@ namespace Tensorflow | |||
| /// <returns></returns> | |||
| public static object trainable_variables() | |||
| { | |||
| return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES); | |||
| return ops.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES); | |||
| } | |||
| /// <summary> | |||
| @@ -40,11 +41,11 @@ namespace Tensorflow | |||
| { | |||
| var all = new List<VariableV1>(); | |||
| var collection = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope); | |||
| var collection = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope); | |||
| if(collection != null) | |||
| all.AddRange(collection as List<VariableV1>); | |||
| collection = ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope); | |||
| collection = ops.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS, scope); | |||
| if (collection != null) | |||
| all.AddRange(collection as List<VariableV1>); | |||
| @@ -64,7 +65,7 @@ namespace Tensorflow | |||
| /// <returns>A list of `Variable` objects.</returns> | |||
| public static List<VariableV1> global_variables(string scope = null) | |||
| { | |||
| var result = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope); | |||
| var result = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope); | |||
| return result == null ? new List<VariableV1>() : result as List<VariableV1>; | |||
| } | |||
| @@ -0,0 +1,40 @@ | |||
| %all_dtypes = ["NDArray","Complex","Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single","String"] | |||
| %all_dtypes_lowercase = ["NDArray","Complex","bool","byte","short","ushort","int","uint","long","ulong","char","double","float","string"] | |||
| %supported_primitives = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single","String"] | |||
| %supported_primitives_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float","string"] | |||
| %supported_numericals = ["Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"] | |||
| %supported_numericals_lowercase = ["byte","short","ushort","int","uint","long","ulong","char","double","float"] | |||
| %supported_numericals_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] | |||
| %supported_numericals_onevales = ["1","1","1","1","1u","1L","1UL",1,"1d","1f"] | |||
| %supported_numericals_TF_DataType = ["TF_UINT8","TF_INT16","TF_UINT16","TF_INT32","TF_UINT32","TF_INT64","TF_UINT64","TF_STRING","TF_DOUBLE","TF_FLOAT"] | |||
| %supported_numericals_TF_DataType_full = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_STRING","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] | |||
| //this is the type we use in summerizing/reducting: | |||
| %supported_numericals_accumulatingType = ["UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"] | |||
| %supported_numericals_accumulatingType_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] | |||
| %supported_numericals_signed = ["Int16","Int32","Int64","Double","Single"] | |||
| %supported_numericals_signed_lowercase = ["short","int","long","double","float"] | |||
| %supported_numericals_signed_defaultvals = ["0","0","0L","0d","0f"] | |||
| %supported_numericals_signed_onevales = ["1","1","1L","1d","1f"] | |||
| %supported_numericals_unsigned = ["Byte","UInt16","UInt32","UInt64","Char"] | |||
| %supported_numericals_unsigned_lowercase = ["byte","ushort","uint","ulong","char"] | |||
| %supported_numericals_unsigned_defaultvals = ["0","0","0U","0UL","'\0'"] | |||
| %supported_numericals_unsigned_onevales = ["1","1","1U","1UL","'\1'"] | |||
| %supported_dtypes = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"] | |||
| %supported_dtypes_TF_DataType = ["TF_BOOL","TF_UINT8","TF_INT16","TF_UINT16","TF_INT32","TF_UINT32","TF_INT64","TF_UINT64","TF_STRING","TF_DOUBLE","TF_FLOAT"] | |||
| %supported_dtypes_TF_DataType_full = ["TF_DataType.TF_BOOL","TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_STRING","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] | |||
| %supported_dtypes_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float"] | |||
| %supported_dtypes_defaultvals = [false,"0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] | |||
| %supported_dtypes_onevales = [true,"1","1","1","1","1u","1L","1UL","'\1'","1d","1f"] | |||
| %supported_dtypes_dtype = ["bool","uint8","int16","uint16","int32","uint32","int64","uint64","uint8","float64","float32"] | |||
| //this is the type we use in summerizing/reducting: | |||
| %supported_dtypes_accumulatingType = ["Int32","UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"] | |||
| %supported_dtypes_accumulatingType_defaultvals = [false, "0","0","0","0u","0L","0UL","'\0'","0d","0f"] | |||
| @@ -27,57 +27,113 @@ namespace Tensorflow | |||
| /// specified, but it is also possible to pass an explicit list of | |||
| /// variables. | |||
| /// </summary> | |||
| public static class GraphKeys | |||
| public class GraphKeys | |||
| { | |||
| #region const | |||
| /// <summary> | |||
| /// the subset of `Variable` objects that will be trained by an optimizer. | |||
| /// </summary> | |||
| public const string TRAINABLE_VARIABLES_ = "trainable_variables"; | |||
| /// <summary> | |||
| /// Trainable resource-style variables. | |||
| /// </summary> | |||
| public const string TRAINABLE_RESOURCE_VARIABLES_ = "trainable_resource_variables"; | |||
| /// <summary> | |||
| /// Key for streaming model ports. | |||
| /// </summary> | |||
| public const string _STREAMING_MODEL_PORTS_ = "streaming_model_ports"; | |||
| /// <summary> | |||
| /// Key to collect losses | |||
| /// </summary> | |||
| public const string LOSSES_ = "losses"; | |||
| /// <summary> | |||
| /// Key to collect Variable objects that are global (shared across machines). | |||
| /// Default collection for all variables, except local ones. | |||
| /// </summary> | |||
| public const string GLOBAL_VARIABLES_ = "variables"; | |||
| public const string TRAIN_OP_ = "train_op"; | |||
| public const string GLOBAL_STEP_ = "global_step"; | |||
| public string[] _VARIABLE_COLLECTIONS_ = new string[] { "variables", "trainable_variables", "model_variables" }; | |||
| /// <summary> | |||
| /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | |||
| /// </summary> | |||
| public const string SAVEABLE_OBJECTS_ = "saveable_objects"; | |||
| /// <summary> | |||
| /// Key to collect update_ops | |||
| /// </summary> | |||
| public const string UPDATE_OPS_ = "update_ops"; | |||
| // Key to collect summaries. | |||
| public const string SUMMARIES_ = "summaries"; | |||
| // Used to store v2 summary names. | |||
| public const string _SUMMARY_COLLECTION_ = "_SUMMARY_V2"; | |||
| // Key for control flow context. | |||
| public const string COND_CONTEXT_ = "cond_context"; | |||
| public const string WHILE_CONTEXT_ = "while_context"; | |||
| #endregion | |||
| /// <summary> | |||
| /// the subset of `Variable` objects that will be trained by an optimizer. | |||
| /// </summary> | |||
| public static string TRAINABLE_VARIABLES = "trainable_variables"; | |||
| public string TRAINABLE_VARIABLES => TRAINABLE_VARIABLES_; | |||
| /// <summary> | |||
| /// Trainable resource-style variables. | |||
| /// </summary> | |||
| public static string TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"; | |||
| public string TRAINABLE_RESOURCE_VARIABLES => TRAINABLE_RESOURCE_VARIABLES_; | |||
| /// <summary> | |||
| /// Key for streaming model ports. | |||
| /// </summary> | |||
| public static string _STREAMING_MODEL_PORTS = "streaming_model_ports"; | |||
| public string _STREAMING_MODEL_PORTS => _STREAMING_MODEL_PORTS_; | |||
| /// <summary> | |||
| /// Key to collect losses | |||
| /// </summary> | |||
| public const string LOSSES = "losses"; | |||
| public string LOSSES => LOSSES_; | |||
| /// <summary> | |||
| /// Key to collect Variable objects that are global (shared across machines). | |||
| /// Default collection for all variables, except local ones. | |||
| /// </summary> | |||
| public static string GLOBAL_VARIABLES = "variables"; | |||
| public string GLOBAL_VARIABLES => GLOBAL_VARIABLES_; | |||
| public static string TRAIN_OP = "train_op"; | |||
| public string TRAIN_OP => TRAIN_OP_; | |||
| public static string GLOBAL_STEP = GLOBAL_STEP = "global_step"; | |||
| public string GLOBAL_STEP => GLOBAL_STEP_; | |||
| public static string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables", "model_variables" }; | |||
| public string[] _VARIABLE_COLLECTIONS => _VARIABLE_COLLECTIONS_; | |||
| /// <summary> | |||
| /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | |||
| /// </summary> | |||
| public static string SAVEABLE_OBJECTS = "saveable_objects"; | |||
| public string SAVEABLE_OBJECTS => SAVEABLE_OBJECTS_; | |||
| /// <summary> | |||
| /// Key to collect update_ops | |||
| /// </summary> | |||
| public static string UPDATE_OPS = "update_ops"; | |||
| public string UPDATE_OPS => UPDATE_OPS_; | |||
| // Key to collect summaries. | |||
| public const string SUMMARIES = "summaries"; | |||
| public string SUMMARIES => SUMMARIES_; | |||
| // Used to store v2 summary names. | |||
| public static string _SUMMARY_COLLECTION = "_SUMMARY_V2"; | |||
| public string _SUMMARY_COLLECTION => _SUMMARY_COLLECTION_; | |||
| // Key for control flow context. | |||
| public static string COND_CONTEXT = "cond_context"; | |||
| public static string WHILE_CONTEXT = "while_context"; | |||
| public string COND_CONTEXT => COND_CONTEXT_; | |||
| public string WHILE_CONTEXT => WHILE_CONTEXT_; | |||
| } | |||
| } | |||
| } | |||
| @@ -230,8 +230,8 @@ namespace Tensorflow | |||
| // Add attrs | |||
| foreach (var attr in node_def.Attr) | |||
| { | |||
| var bytes = attr.Value.ToByteArray(); | |||
| var proto = Marshal.AllocHGlobal(bytes.Length); | |||
| var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||
| var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak | |||
| Marshal.Copy(bytes, 0, proto, bytes.Length); | |||
| uint len = (uint)bytes.Length; | |||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); | |||
| @@ -488,6 +488,8 @@ namespace Tensorflow | |||
| switch (value) | |||
| { | |||
| case String str: | |||
| return constant_op.constant(str, dtype: TF_DataType.TF_STRING, name: name); | |||
| case NDArray nd: | |||
| return constant_op.constant(nd, dtype: dtype, name: name); | |||
| case Tensor tensor: | |||
| @@ -64,13 +64,12 @@ namespace Tensorflow | |||
| public Session Session() | |||
| { | |||
| defaultSession = new Session(); | |||
| return defaultSession; | |||
| return new Session(); | |||
| } | |||
| public Session Session(Graph graph) | |||
| public Session Session(Graph graph, SessionOptions opts = null) | |||
| { | |||
| return new Session(graph); | |||
| return new Session(graph, opts: opts); | |||
| } | |||
| public Session Session(SessionOptions opts) | |||
| @@ -9,24 +9,18 @@ namespace TensorFlowBenchmark | |||
| { | |||
| static void Main(string[] args) | |||
| { | |||
| #if DEBUG | |||
| IConfig config = new DebugInProcessConfig(); | |||
| #else | |||
| IConfig config = null; | |||
| #endif | |||
| if (args?.Length > 0) | |||
| { | |||
| for (int i = 0; i < args.Length; i++) | |||
| { | |||
| string name = $"TensorFlowBenchmark.{args[i]}"; | |||
| var type = Type.GetType(name); | |||
| BenchmarkRunner.Run(type, config); | |||
| BenchmarkRunner.Run(type); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| BenchmarkSwitcher.FromAssembly(Assembly.GetExecutingAssembly()).Run(args, config); | |||
| BenchmarkSwitcher.FromAssembly(Assembly.GetExecutingAssembly()).Run(args, ManualConfig.Create(DefaultConfig.Instance).With(ConfigOptions.DisableOptimizationsValidator)); | |||
| } | |||
| Console.ReadLine(); | |||
| @@ -6,6 +6,7 @@ | |||
| <NoWin32Manifest>true</NoWin32Manifest> | |||
| <AssemblyName>TensorFlowBenchmark</AssemblyName> | |||
| <RootNamespace>TensorFlowBenchmark</RootNamespace> | |||
| <LangVersion>7.3</LangVersion> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | |||
| @@ -0,0 +1,76 @@ | |||
| using System; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Runtime.InteropServices; | |||
| using BenchmarkDotNet.Attributes; | |||
| using Google.Protobuf.WellKnownTypes; | |||
| using NumSharp; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowBenchmark.Unmanaged | |||
| { | |||
| public struct UnmanagedStruct | |||
| { | |||
| public int a; | |||
| public long b; | |||
| public UnmanagedStruct(int _) | |||
| { | |||
| a = 2; | |||
| b = 3; | |||
| } | |||
| } | |||
| [SimpleJob(launchCount: 1, warmupCount: 2, targetCount: 10)] | |||
| [MinColumn, MaxColumn, MeanColumn, MedianColumn] | |||
| public unsafe class StructCastBenchmark | |||
| { | |||
| private static void EnsureIsUnmanaged<T>(T _) where T : unmanaged | |||
| { } | |||
| static StructCastBenchmark() //if UnmanagedStruct is not unmanaged struct then this will fail to compile. | |||
| => EnsureIsUnmanaged(new UnmanagedStruct()); | |||
| private IntPtr data; | |||
| private void* dataptr; | |||
| [GlobalSetup] | |||
| public void Setup() | |||
| { | |||
| data = Marshal.AllocHGlobal(Marshal.SizeOf<UnmanagedStruct>()); | |||
| dataptr = data.ToPointer(); | |||
| } | |||
| [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)] | |||
| public void Marshal_PtrToStructure() | |||
| { | |||
| UnmanagedStruct _; | |||
| for (int i = 0; i < 10000; i++) | |||
| { | |||
| _ = Marshal.PtrToStructure<UnmanagedStruct>(data); | |||
| } | |||
| } | |||
| [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)] | |||
| public void PointerCast() | |||
| { | |||
| var dptr = dataptr; | |||
| UnmanagedStruct _; | |||
| for (int i = 0; i < 10000; i++) | |||
| { | |||
| _ = *(UnmanagedStruct*) dptr; | |||
| } | |||
| } | |||
| [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)] | |||
| public void Unsafe_Read() | |||
| { | |||
| var dptr = dataptr; | |||
| UnmanagedStruct _; | |||
| for (int i = 0; i < 10000; i++) | |||
| { | |||
| _ = Unsafe.Read<UnmanagedStruct>(dptr); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -80,26 +80,18 @@ namespace TensorFlowNET.Examples | |||
| for (int epoch = 0; epoch < training_epochs; epoch++) | |||
| { | |||
| foreach (var (x, y) in zip<float>(train_X, train_Y)) | |||
| { | |||
| sess.run(optimizer, | |||
| new FeedItem(X, x), | |||
| new FeedItem(Y, y)); | |||
| } | |||
| sess.run(optimizer, (X, x), (Y, y)); | |||
| // Display logs per epoch step | |||
| if ((epoch + 1) % display_step == 0) | |||
| { | |||
| var c = sess.run(cost, | |||
| new FeedItem(X, train_X), | |||
| new FeedItem(Y, train_Y)); | |||
| var c = sess.run(cost, (X, train_X), (Y, train_Y)); | |||
| Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}"); | |||
| } | |||
| } | |||
| Console.WriteLine("Optimization Finished!"); | |||
| var training_cost = sess.run(cost, | |||
| new FeedItem(X, train_X), | |||
| new FeedItem(Y, train_Y)); | |||
| var training_cost = sess.run(cost, (X, train_X), (Y, train_Y)); | |||
| Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); | |||
| // Testing example | |||
| @@ -107,8 +99,7 @@ namespace TensorFlowNET.Examples | |||
| var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); | |||
| Console.WriteLine("Testing... (Mean square loss Comparison)"); | |||
| var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), | |||
| new FeedItem(X, test_X), | |||
| new FeedItem(Y, test_Y)); | |||
| (X, test_X), (Y, test_Y)); | |||
| Console.WriteLine($"Testing cost={testing_cost}"); | |||
| var diff = Math.Abs((float)training_cost - (float)testing_cost); | |||
| Console.WriteLine($"Absolute mean square loss difference: {diff}"); | |||
| @@ -102,7 +102,7 @@ namespace TensorFlowNET.Examples | |||
| // Display logs per epoch step | |||
| if ((epoch + 1) % display_step == 0) | |||
| print($"Epoch: {(epoch + 1).ToString("D4")} Cost: {avg_cost.ToString("G9")} Elapse: {sw.ElapsedMilliseconds}ms"); | |||
| print($"Epoch: {(epoch + 1):D4} Cost: {avg_cost:G9} Elapse: {sw.ElapsedMilliseconds}ms"); | |||
| sw.Reset(); | |||
| } | |||
| @@ -114,8 +114,8 @@ namespace TensorFlowNET.Examples | |||
| var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); | |||
| // Calculate accuracy | |||
| var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); | |||
| float acc = accuracy.eval((x, mnist.Test.Data), (y, mnist.Test.Labels)); | |||
| print($"Accuracy: {acc.ToString("F4")}"); | |||
| float acc = accuracy.eval(sess, (x, mnist.Test.Data), (y, mnist.Test.Labels)); | |||
| print($"Accuracy: {acc:F4}"); | |||
| return acc > 0.9; | |||
| } | |||
| @@ -84,7 +84,7 @@ namespace TensorFlowNET.Examples | |||
| public void PrepareData() | |||
| { | |||
| // get model file | |||
| string url = "http://download.tf.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz"; | |||
| string url = "http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz"; | |||
| Web.Download(url, modelDir, "ssd_mobilenet_v1_coco.tar.gz"); | |||
| Compress.ExtractTGZ(Path.Join(modelDir, "ssd_mobilenet_v1_coco.tar.gz"), "./"); | |||
| @@ -21,6 +21,7 @@ using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using System.Threading.Tasks; | |||
| using Tensorflow; | |||
| using TensorFlowNET.Examples.Utility; | |||
| using static Tensorflow.Binding; | |||
| @@ -381,10 +382,15 @@ namespace TensorFlowNET.Examples | |||
| Tensor resized_input_tensor, Tensor bottleneck_tensor, string module_name) | |||
| { | |||
| int how_many_bottlenecks = 0; | |||
| foreach (var (label_name, label_lists) in image_lists) | |||
| var kvs = image_lists.ToArray(); | |||
| var categories = new string[] {"training", "testing", "validation"}; | |||
| Parallel.For(0, kvs.Length, i => | |||
| { | |||
| foreach (var category in new string[] { "training", "testing", "validation" }) | |||
| var (label_name, label_lists) = kvs[i]; | |||
| Parallel.For(0, categories.Length, j => | |||
| { | |||
| var category = categories[j]; | |||
| var category_list = label_lists[category]; | |||
| foreach (var (index, unused_base_name) in enumerate(category_list)) | |||
| { | |||
| @@ -395,8 +401,8 @@ namespace TensorFlowNET.Examples | |||
| if (how_many_bottlenecks % 300 == 0) | |||
| print($"{how_many_bottlenecks} bottleneck files created."); | |||
| } | |||
| } | |||
| } | |||
| }); | |||
| }); | |||
| } | |||
| private float[] get_or_create_bottleneck(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists, | |||
| @@ -508,7 +514,7 @@ namespace TensorFlowNET.Examples | |||
| { | |||
| // get a set of images to teach the network about the new classes | |||
| string fileName = "flower_photos.tgz"; | |||
| string url = $"http://download.tf.org/example_images/{fileName}"; | |||
| string url = $"http://download.tensorflow.org/example_images/{fileName}"; | |||
| Web.Download(url, data_dir, fileName); | |||
| Compress.ExtractTGZ(Path.Join(data_dir, fileName), data_dir); | |||
| @@ -0,0 +1,54 @@ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Text; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||
| { | |||
| public class Dataset | |||
| { | |||
| string annot_path; | |||
| int[] input_sizes; | |||
| int batch_size; | |||
| bool data_aug; | |||
| int[] train_input_sizes; | |||
| NDArray strides; | |||
| NDArray anchors; | |||
| Dictionary<int, string> classes; | |||
| int num_classes; | |||
| int anchor_per_scale; | |||
| int max_bbox_per_scale; | |||
| string[] annotations; | |||
| int num_samples; | |||
| int batch_count; | |||
| public int Length = 0; | |||
| public Dataset(string dataset_type, Config cfg) | |||
| { | |||
| annot_path = dataset_type == "train" ? cfg.TRAIN.ANNOT_PATH : cfg.TEST.ANNOT_PATH; | |||
| input_sizes = dataset_type == "train" ? cfg.TRAIN.INPUT_SIZE : cfg.TEST.INPUT_SIZE; | |||
| batch_size = dataset_type == "train" ? cfg.TRAIN.BATCH_SIZE : cfg.TEST.BATCH_SIZE; | |||
| data_aug = dataset_type == "train" ? cfg.TRAIN.DATA_AUG : cfg.TEST.DATA_AUG; | |||
| train_input_sizes = cfg.TRAIN.INPUT_SIZE; | |||
| strides = np.array(cfg.YOLO.STRIDES); | |||
| classes = Utils.read_class_names(cfg.YOLO.CLASSES); | |||
| num_classes = classes.Count; | |||
| anchors = np.array(Utils.get_anchors(cfg.YOLO.ANCHORS)); | |||
| anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE; | |||
| max_bbox_per_scale = 150; | |||
| annotations = load_annotations(); | |||
| num_samples = len(annotations); | |||
| batch_count = 0; | |||
| } | |||
| string[] load_annotations() | |||
| { | |||
| return File.ReadAllLines(annot_path); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,139 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Text; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||
| { | |||
| /// <summary> | |||
| /// Implementation of YOLO v3 object detector in Tensorflow | |||
| /// https://github.com/YunYang1994/tensorflow-yolov3 | |||
| /// </summary> | |||
| public class Main : IExample | |||
| { | |||
| public bool Enabled { get; set; } = true; | |||
| public bool IsImportingGraph { get; set; } = false; | |||
| public string Name => "YOLOv3"; | |||
| #region args | |||
| Dictionary<int, string> classes; | |||
| int num_classes; | |||
| float learn_rate_init; | |||
| float learn_rate_end; | |||
| int first_stage_epochs; | |||
| int second_stage_epochs; | |||
| int warmup_periods; | |||
| string time; | |||
| float moving_ave_decay; | |||
| int max_bbox_per_scale; | |||
| int steps_per_period; | |||
| Dataset trainset, testset; | |||
| Config cfg; | |||
| Tensor input_data; | |||
| Tensor label_sbbox; | |||
| Tensor label_mbbox; | |||
| Tensor label_lbbox; | |||
| Tensor true_sbboxes; | |||
| Tensor true_mbboxes; | |||
| Tensor true_lbboxes; | |||
| Tensor trainable; | |||
| Session sess; | |||
| YOLOv3 model; | |||
| #endregion | |||
| public bool Run() | |||
| { | |||
| PrepareData(); | |||
| var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); | |||
| var options = new SessionOptions(); | |||
| options.SetConfig(new ConfigProto { AllowSoftPlacement = true }); | |||
| using (var sess = tf.Session(graph, opts: options)) | |||
| { | |||
| Train(sess); | |||
| } | |||
| return true; | |||
| } | |||
| public void Train(Session sess) | |||
| { | |||
| } | |||
| public void Test(Session sess) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public Graph BuildGraph() | |||
| { | |||
| var graph = new Graph().as_default(); | |||
| tf_with(tf.name_scope("define_input"), scope => | |||
| { | |||
| input_data = tf.placeholder(dtype: tf.float32, name: "input_data"); | |||
| label_sbbox = tf.placeholder(dtype: tf.float32, name: "label_sbbox"); | |||
| label_mbbox = tf.placeholder(dtype: tf.float32, name: "label_mbbox"); | |||
| label_lbbox = tf.placeholder(dtype: tf.float32, name: "label_lbbox"); | |||
| true_sbboxes = tf.placeholder(dtype: tf.float32, name: "sbboxes"); | |||
| true_mbboxes = tf.placeholder(dtype: tf.float32, name: "mbboxes"); | |||
| true_lbboxes = tf.placeholder(dtype: tf.float32, name: "lbboxes"); | |||
| trainable = tf.placeholder(dtype: tf.@bool, name: "training"); | |||
| }); | |||
| tf_with(tf.name_scope("define_loss"), scope => | |||
| { | |||
| model = new YOLOv3(cfg, input_data, trainable); | |||
| }); | |||
| tf_with(tf.name_scope("define_weight_decay"), scope => | |||
| { | |||
| var moving_ave = tf.train.ExponentialMovingAverage(moving_ave_decay).apply((RefVariable[])tf.trainable_variables()); | |||
| }); | |||
| return graph; | |||
| } | |||
| public Graph ImportGraph() | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public void Predict(Session sess) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public void PrepareData() | |||
| { | |||
| cfg = new Config(Name); | |||
| string dataDir = Path.Combine(Name, "data"); | |||
| Directory.CreateDirectory(dataDir); | |||
| classes = Utils.read_class_names(cfg.YOLO.CLASSES); | |||
| num_classes = classes.Count; | |||
| learn_rate_init = cfg.TRAIN.LEARN_RATE_INIT; | |||
| learn_rate_end = cfg.TRAIN.LEARN_RATE_END; | |||
| first_stage_epochs = cfg.TRAIN.FISRT_STAGE_EPOCHS; | |||
| second_stage_epochs = cfg.TRAIN.SECOND_STAGE_EPOCHS; | |||
| warmup_periods = cfg.TRAIN.WARMUP_EPOCHS; | |||
| DateTime now = DateTime.Now; | |||
| time = $"{now.Year}-{now.Month}-{now.Day}-{now.Hour}-{now.Minute}-{now.Minute}"; | |||
| moving_ave_decay = cfg.YOLO.MOVING_AVE_DECAY; | |||
| max_bbox_per_scale = 150; | |||
| trainset = new Dataset("train", cfg); | |||
| testset = new Dataset("test", cfg); | |||
| steps_per_period = trainset.Length; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,27 @@ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using System.Text; | |||
| namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||
| { | |||
| class Utils | |||
| { | |||
| public static Dictionary<int, string> read_class_names(string file) | |||
| { | |||
| var classes = new Dictionary<int, string>(); | |||
| foreach (var line in File.ReadAllLines(file)) | |||
| classes[classes.Count] = line; | |||
| return classes; | |||
| } | |||
| public static NDArray get_anchors(string file) | |||
| { | |||
| return np.array(File.ReadAllText(file).Split(',') | |||
| .Select(x => float.Parse(x)) | |||
| .ToArray()).reshape(3, 3, 2); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,65 @@ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||
| { | |||
| public class YOLOv3 | |||
| { | |||
| Config cfg; | |||
| Tensor trainable; | |||
| Tensor input_data; | |||
| Dictionary<int, string> classes; | |||
| int num_class; | |||
| NDArray strides; | |||
| NDArray anchors; | |||
| int anchor_per_scale; | |||
| float iou_loss_thresh; | |||
| string upsample_method; | |||
| Tensor conv_lbbox; | |||
| Tensor conv_mbbox; | |||
| Tensor conv_sbbox; | |||
| public YOLOv3(Config cfg_, Tensor input_data_, Tensor trainable_) | |||
| { | |||
| cfg = cfg_; | |||
| input_data = input_data_; | |||
| trainable = trainable_; | |||
| classes = Utils.read_class_names(cfg.YOLO.CLASSES); | |||
| num_class = len(classes); | |||
| strides = np.array(cfg.YOLO.STRIDES); | |||
| anchors = Utils.get_anchors(cfg.YOLO.ANCHORS); | |||
| anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE; | |||
| iou_loss_thresh = cfg.YOLO.IOU_LOSS_THRESH; | |||
| upsample_method = cfg.YOLO.UPSAMPLE_METHOD; | |||
| (conv_lbbox, conv_mbbox, conv_sbbox) = __build_nework(input_data); | |||
| tf_with(tf.variable_scope("pred_sbbox"), scope => | |||
| { | |||
| // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); | |||
| }); | |||
| tf_with(tf.variable_scope("pred_mbbox"), scope => | |||
| { | |||
| // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); | |||
| }); | |||
| tf_with(tf.variable_scope("pred_lbbox"), scope => | |||
| { | |||
| // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); | |||
| }); | |||
| } | |||
| private (Tensor, Tensor, Tensor) __build_nework(Tensor input_data) | |||
| { | |||
| Tensor route_1, route_2; | |||
| (route_1, route_2, input_data) = backbone.darknet53(input_data, trainable); | |||
| return (conv_lbbox, conv_mbbox, conv_sbbox); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,28 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||
| { | |||
| class backbone | |||
| { | |||
| public static (Tensor, Tensor, Tensor) darknet53(Tensor input_data, Tensor trainable) | |||
| { | |||
| return tf_with(tf.variable_scope("darknet"), scope => | |||
| { | |||
| input_data = common.convolutional(input_data, filters_shape: new int[] { 3, 3, 3, 32 }, trainable: trainable, name: "conv0"); | |||
| input_data = common.convolutional(input_data, filters_shape: new int[] { 3, 3, 32, 64 }, trainable: trainable, name: "conv1", downsample: true); | |||
| foreach (var i in range(1)) | |||
| input_data = common.residual_block(input_data, 64, 32, 64, trainable: trainable, name: $"residual{i + 0}"); | |||
| var route_1 = input_data; | |||
| var route_2 = input_data; | |||
| return (route_1, route_2, input_data); | |||
| }); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,72 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||
| { | |||
| class common | |||
| { | |||
| public static Tensor convolutional(Tensor input_data, int[] filters_shape, Tensor trainable, | |||
| string name, bool downsample = false, bool activate = true, | |||
| bool bn = true) | |||
| { | |||
| return tf_with(tf.variable_scope(name), scope => | |||
| { | |||
| int[] strides; | |||
| string padding; | |||
| if (downsample) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| } | |||
| else | |||
| { | |||
| strides = new int[] { 1, 1, 1, 1 }; | |||
| padding = "SAME"; | |||
| } | |||
| var weight = tf.get_variable(name: "weight", dtype: tf.float32, trainable: true, | |||
| shape: filters_shape, initializer: tf.random_normal_initializer(stddev: 0.01f)); | |||
| var conv = tf.nn.conv2d(input: input_data, filter: weight, strides: strides, padding: padding); | |||
| if (bn) | |||
| { | |||
| conv = tf.layers.batch_normalization(conv, beta_initializer: tf.zeros_initializer, | |||
| gamma_initializer: tf.ones_initializer, | |||
| moving_mean_initializer: tf.zeros_initializer, | |||
| moving_variance_initializer: tf.ones_initializer, training: trainable); | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException(""); | |||
| } | |||
| if (activate) | |||
| conv = tf.nn.leaky_relu(conv, alpha: 0.1f); | |||
| return conv; | |||
| }); | |||
| } | |||
| public static Tensor residual_block(Tensor input_data, int input_channel, int filter_num1, | |||
| int filter_num2, Tensor trainable, string name) | |||
| { | |||
| var short_cut = input_data; | |||
| return tf_with(tf.variable_scope(name), scope => | |||
| { | |||
| input_data = convolutional(input_data, filters_shape: new int[] { 1, 1, input_channel, filter_num1 }, | |||
| trainable: trainable, name: "conv1"); | |||
| input_data = convolutional(input_data, filters_shape: new int[] { 3, 3, filter_num1, filter_num2 }, | |||
| trainable: trainable, name: "conv2"); | |||
| var residual_output = input_data + short_cut; | |||
| return residual_output; | |||
| }); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,94 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Text; | |||
| namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||
| { | |||
| public class Config | |||
| { | |||
| public YoloConfig YOLO; | |||
| public TrainConfig TRAIN; | |||
| public TestConfig TEST; | |||
| public Config(string root) | |||
| { | |||
| YOLO = new YoloConfig(root); | |||
| TRAIN = new TrainConfig(root); | |||
| TEST = new TestConfig(root); | |||
| } | |||
| public class YoloConfig | |||
| { | |||
| string _root; | |||
| public string CLASSES; | |||
| public string ANCHORS; | |||
| public float MOVING_AVE_DECAY = 0.9995f; | |||
| public int[] STRIDES = new int[] { 8, 16, 32 }; | |||
| public int ANCHOR_PER_SCALE = 3; | |||
| public float IOU_LOSS_THRESH = 0.5f; | |||
| public string UPSAMPLE_METHOD = "resize"; | |||
| public string ORIGINAL_WEIGHT; | |||
| public string DEMO_WEIGHT; | |||
| public YoloConfig(string root) | |||
| { | |||
| _root = root; | |||
| CLASSES = Path.Combine(_root, "data", "classes", "coco.names"); | |||
| ANCHORS = Path.Combine(_root, "data", "anchors", "basline_anchors.txt"); | |||
| ORIGINAL_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco.ckpt"); | |||
| DEMO_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco_demo.ckpt"); | |||
| } | |||
| } | |||
| public class TrainConfig | |||
| { | |||
| string _root; | |||
| public int BATCH_SIZE = 6; | |||
| public int[] INPUT_SIZE = new int[] { 320, 352, 384, 416, 448, 480, 512, 544, 576, 608 }; | |||
| public bool DATA_AUG = true; | |||
| public float LEARN_RATE_INIT = 1e-4f; | |||
| public float LEARN_RATE_END = 1e-6f; | |||
| public int WARMUP_EPOCHS = 2; | |||
| public int FISRT_STAGE_EPOCHS = 20; | |||
| public int SECOND_STAGE_EPOCHS = 30; | |||
| public string INITIAL_WEIGHT; | |||
| public string ANNOT_PATH; | |||
| public TrainConfig(string root) | |||
| { | |||
| _root = root; | |||
| INITIAL_WEIGHT = Path.Combine(_root, "data", "checkpoint", "yolov3_coco_demo.ckpt"); | |||
| ANNOT_PATH = Path.Combine(_root, "data", "dataset", "voc_train.txt"); | |||
| } | |||
| } | |||
| public class TestConfig | |||
| { | |||
| string _root; | |||
| public int BATCH_SIZE = 2; | |||
| public int[] INPUT_SIZE = new int[] { 544 }; | |||
| public bool DATA_AUG = false; | |||
| public bool WRITE_IMAGE = true; | |||
| public string WRITE_IMAGE_PATH; | |||
| public string WEIGHT_FILE; | |||
| public bool WRITE_IMAGE_SHOW_LABEL = true; | |||
| public bool SHOW_LABEL = true; | |||
| public int SECOND_STAGE_EPOCHS = 30; | |||
| public float SCORE_THRESHOLD = 0.3f; | |||
| public float IOU_THRESHOLD = 0.45f; | |||
| public string ANNOT_PATH; | |||
| public TestConfig(string root) | |||
| { | |||
| _root = root; | |||
| ANNOT_PATH = Path.Combine(_root, "data", "dataset", "voc_test.txt"); | |||
| WRITE_IMAGE_PATH = Path.Combine(_root, "data", "detection"); | |||
| WEIGHT_FILE = Path.Combine(_root, "checkpoint", "yolov3_test_loss=9.2099.ckpt-5"); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -6,6 +6,14 @@ | |||
| <GeneratePackageOnBuild>false</GeneratePackageOnBuild> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
| <OutputPath>bin\debug-gpu</OutputPath> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | |||
| <OutputPath>bin\release-gpu</OutputPath> | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| <PackageReference Include="Colorful.Console" Version="1.2.9" /> | |||
| <PackageReference Include="Newtonsoft.Json" Version="12.0.2" /> | |||