| @@ -1,6 +1,6 @@ | |||||
|  |  | ||||
| **TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. | |||||
| **TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in C# which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. | |||||
| [](https://gitter.im/sci-sharp/community) | [](https://gitter.im/sci-sharp/community) | ||||
| [](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net) | [](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net) | ||||
| @@ -16,7 +16,7 @@ TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). <a | |||||
| ### Why TensorFlow.NET ? | ### Why TensorFlow.NET ? | ||||
| `SciSharp STACK`'s mission is to bring popular data science technology into the .NET world and to provide .NET developers with a powerful Machine Learning tool set without reinventing the wheel. Scince the APIs are kept as similar as possible you can immediately adapt any existing Tensorflow code in C# with a zero learning curve. Take a look at a comparison picture and see how comfortably a Tensorflow/Python script translates into a C# program with TensorFlow.NET. | |||||
| `SciSharp STACK`'s mission is to bring popular data science technology into the .NET world and to provide .NET developers with a powerful Machine Learning tool set without reinventing the wheel. Since the APIs are kept as similar as possible you can immediately adapt any existing Tensorflow code in C# with a zero learning curve. Take a look at a comparison picture and see how comfortably a Tensorflow/Python script translates into a C# program with TensorFlow.NET. | |||||
|  |  | ||||
| @@ -34,40 +34,15 @@ PM> Install-Package TensorFlow.NET | |||||
| ### Install tensorflow binary | ### Install tensorflow binary | ||||
| ### For CPU version | ### For CPU version | ||||
| PM> Install-Package SciSharp.TensorFlow.Redist | PM> Install-Package SciSharp.TensorFlow.Redist | ||||
| ### For GPU version (CUDA and cuDNN are required) | ### For GPU version (CUDA and cuDNN are required) | ||||
| PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU | PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU | ||||
| ``` | ``` | ||||
| Import TF.NET. | |||||
| ```cs | |||||
| using Tensorflow; | |||||
| ``` | |||||
| Add two constants: | |||||
| ```cs | |||||
| // Create a Constant op | |||||
| var a = tf.constant(4.0f); | |||||
| var b = tf.constant(5.0f); | |||||
| var c = tf.add(a, b); | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| var o = sess.run(c); | |||||
| } | |||||
| ``` | |||||
| Import TF.NET in your project. | |||||
| Feed placeholder: | |||||
| ```cs | ```cs | ||||
| // Create a placeholder op | |||||
| var a = tf.placeholder(tf.float32); | |||||
| var b = tf.placeholder(tf.float32); | |||||
| var c = tf.add(a, b); | |||||
| using(var sess = tf.Session()) | |||||
| { | |||||
| var o = sess.run(c, new FeedItem(a, 3.0f), new FeedItem(b, 2.0f)); | |||||
| } | |||||
| using static Tensorflow.Binding; | |||||
| ``` | ``` | ||||
| Linear Regression: | Linear Regression: | ||||
| @@ -91,39 +66,40 @@ var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); | |||||
| var init = tf.global_variables_initializer(); | var init = tf.global_variables_initializer(); | ||||
| // Start training | // Start training | ||||
| with(tf.Session(), sess => | |||||
| using(tf.Session()) | |||||
| { | { | ||||
| // Run the initializer | // Run the initializer | ||||
| sess.run(init); | sess.run(init); | ||||
| // Fit all training data | // Fit all training data | ||||
| for (int epoch = 0; epoch < training_epochs; epoch++) | for (int epoch = 0; epoch < training_epochs; epoch++) | ||||
| { | { | ||||
| foreach (var (x, y) in zip<float>(train_X, train_Y)) | foreach (var (x, y) in zip<float>(train_X, train_Y)) | ||||
| sess.run(optimizer, new FeedItem(X, x), new FeedItem(Y, y)); | |||||
| sess.run(optimizer, (X, x), (Y, y)); | |||||
| // Display logs per epoch step | // Display logs per epoch step | ||||
| if ((epoch + 1) % display_step == 0) | if ((epoch + 1) % display_step == 0) | ||||
| { | { | ||||
| var c = sess.run(cost, new FeedItem(X, train_X), new FeedItem(Y, train_Y)); | |||||
| var c = sess.run(cost, (X, train_X), (Y, train_Y)); | |||||
| Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}"); | Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}"); | ||||
| } | } | ||||
| Console.WriteLine("Optimization Finished!"); | |||||
| var training_cost = sess.run(cost, new FeedItem(X, train_X), new FeedItem(Y, train_Y)); | |||||
| Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); | |||||
| // Testing example | |||||
| var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f); | |||||
| var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); | |||||
| Console.WriteLine("Testing... (Mean square loss Comparison)"); | |||||
| var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), new FeedItem(X, test_X), new FeedItem(Y, test_Y)); | |||||
| Console.WriteLine($"Testing cost={testing_cost}"); | |||||
| var diff = Math.Abs((float)training_cost - (float)testing_cost); | |||||
| Console.WriteLine($"Absolute mean square loss difference: {diff}"); | |||||
| } | } | ||||
| Console.WriteLine("Optimization Finished!"); | |||||
| var training_cost = sess.run(cost, (X, train_X), (Y, train_Y)); | |||||
| Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); | |||||
| // Testing example | |||||
| var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f); | |||||
| var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); | |||||
| Console.WriteLine("Testing... (Mean square loss Comparison)"); | |||||
| var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), | |||||
| (X, test_X), (Y, test_Y)); | |||||
| Console.WriteLine($"Testing cost={testing_cost}"); | |||||
| var diff = Math.Abs((float)training_cost - (float)testing_cost); | |||||
| Console.WriteLine($"Absolute mean square loss difference: {diff}"); | |||||
| return diff < 0.01; | |||||
| }); | }); | ||||
| ``` | ``` | ||||
| @@ -17,6 +17,6 @@ | |||||
| <PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&v=4</PackageIconUrl> | <PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&v=4</PackageIconUrl> | ||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="NumSharp" Version="0.20.0-alpha1" /> | |||||
| <PackageReference Include="NumSharp" Version="0.20.0" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| </Project> | </Project> | ||||
| @@ -59,6 +59,6 @@ namespace Tensorflow | |||||
| } | } | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static unsafe extern IntPtr TF_Version(); | |||||
| public static extern IntPtr TF_Version(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -14,11 +14,16 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using static Tensorflow.ops; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public partial class tensorflow | public partial class tensorflow | ||||
| { | { | ||||
| public graph_util_impl graph_util => new graph_util_impl(); | public graph_util_impl graph_util => new graph_util_impl(); | ||||
| public GraphKeys GraphKeys { get; } = new GraphKeys(); | |||||
| public Graph get_default_graph() | public Graph get_default_graph() | ||||
| { | { | ||||
| return ops.get_default_graph(); | return ops.get_default_graph(); | ||||
| @@ -0,0 +1,61 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System.Collections.Generic; | |||||
| using Tensorflow.IO; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public partial class tensorflow | |||||
| { | |||||
| public image_internal image = new image_internal(); | |||||
| public class image_internal | |||||
| { | |||||
| public Tensor decode_jpeg(Tensor contents, | |||||
| int channels = 0, | |||||
| int ratio = 1, | |||||
| bool fancy_upscaling = true, | |||||
| bool try_recover_truncated = false, | |||||
| float acceptable_fraction = 1, | |||||
| string dct_method = "", | |||||
| string name = null) | |||||
| => gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio, | |||||
| fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated, | |||||
| acceptable_fraction: acceptable_fraction, dct_method: dct_method); | |||||
| public Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, string name = null) | |||||
| => gen_image_ops.resize_bilinear(images, size, align_corners: align_corners, name: name); | |||||
| public Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, string name = null) | |||||
| => gen_image_ops.convert_image_dtype(image, dtype, saturate: saturate, name: name); | |||||
| public Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype = TF_DataType.TF_UINT8, | |||||
| string name = null, bool expand_animations = true) | |||||
| => image_ops_impl.decode_image(contents, channels: channels, dtype: dtype, | |||||
| name: name, expand_animations: expand_animations); | |||||
| /// <summary> | |||||
| /// Convenience function to check if the 'contents' encodes a JPEG image. | |||||
| /// </summary> | |||||
| /// <param name="contents"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor is_jpeg(Tensor contents, string name = null) | |||||
| => image_ops_impl.is_jpeg(contents, name: name); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -52,5 +52,13 @@ namespace Tensorflow | |||||
| stddev: stddev, | stddev: stddev, | ||||
| seed: seed, | seed: seed, | ||||
| dtype: dtype); | dtype: dtype); | ||||
| public IInitializer random_normal_initializer(float mean = 0.0f, | |||||
| float stddev = 1.0f, | |||||
| int? seed = null, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid) => new RandomNormal(mean: mean, | |||||
| stddev: stddev, | |||||
| seed: seed, | |||||
| dtype: dtype); | |||||
| } | } | ||||
| } | } | ||||
| @@ -24,8 +24,6 @@ namespace Tensorflow | |||||
| public GFile gfile = new GFile(); | public GFile gfile = new GFile(); | ||||
| public Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name); | public Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name); | ||||
| public gen_image_ops image => new gen_image_ops(); | |||||
| public void import_graph_def(GraphDef graph_def, | public void import_graph_def(GraphDef graph_def, | ||||
| Dictionary<string, Tensor> input_map = null, | Dictionary<string, Tensor> input_map = null, | ||||
| string[] return_elements = null, | string[] return_elements = null, | ||||
| @@ -148,6 +148,24 @@ namespace Tensorflow | |||||
| }); | }); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Local Response Normalization. | |||||
| /// </summary> | |||||
| /// <param name="input"></param> | |||||
| /// <param name="depth_radius"></param> | |||||
| /// <param name="bias"></param> | |||||
| /// <param name="alpha"></param> | |||||
| /// <param name="beta"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public Tensor lrn(Tensor input, int depth_radius = 5, int bias = 1, | |||||
| int alpha = 1, float beta = 0.5f, string name = null) | |||||
| => gen_nn_ops.local_response_normalization(input, depth_radius: depth_radius, bias: bias, | |||||
| alpha: alpha, beta: beta, name: name); | |||||
| public Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) | |||||
| => nn_ops.leaky_relu(features, alpha: alpha, name: name); | |||||
| public rnn_cell_impl rnn_cell => new rnn_cell_impl(); | public rnn_cell_impl rnn_cell => new rnn_cell_impl(); | ||||
| public Tensor softmax(Tensor logits, int axis = -1, string name = null) | public Tensor softmax(Tensor logits, int axis = -1, string name = null) | ||||
| @@ -33,5 +33,13 @@ namespace Tensorflow | |||||
| /// <returns>The scope name.</returns> | /// <returns>The scope name.</returns> | ||||
| public ops.NameScope name_scope(string name, string default_name = "", object values = null) | public ops.NameScope name_scope(string name, string default_name = "", object values = null) | ||||
| => new ops.NameScope(name, default_name, values); | => new ops.NameScope(name, default_name, values); | ||||
| /// <summary> | |||||
| /// Does nothing. Only useful as a placeholder for control edges. | |||||
| /// </summary> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public Tensor no_op(string name = null) | |||||
| => gen_control_flow_ops.no_op(name: name); | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,32 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System.Collections.Generic; | |||||
| using Tensorflow.IO; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public partial class tensorflow | |||||
| { | |||||
| public strings_internal strings = new strings_internal(); | |||||
| public class strings_internal | |||||
| { | |||||
| public Tensor substr(Tensor input, int pos, int len, | |||||
| string name = null, string @uint = "BYTE") | |||||
| => string_ops.substr(input, pos, len, name: name, @uint: @uint); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -31,6 +31,9 @@ namespace Tensorflow | |||||
| public Optimizer AdamOptimizer(float learning_rate, string name = "Adam") | public Optimizer AdamOptimizer(float learning_rate, string name = "Adam") | ||||
| => new AdamOptimizer(learning_rate, name: name); | => new AdamOptimizer(learning_rate, name: name); | ||||
| public ExponentialMovingAverage ExponentialMovingAverage(float decay) | |||||
| => new ExponentialMovingAverage(decay); | |||||
| public Saver Saver(VariableV1[] var_list = null) => new Saver(var_list: var_list); | public Saver Saver(VariableV1[] var_list = null) => new Saver(var_list: var_list); | ||||
| public string write_graph(Graph graph, string logdir, string name, bool as_text = true) | public string write_graph(Graph graph, string logdir, string name, bool as_text = true) | ||||
| @@ -15,6 +15,7 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -22,7 +23,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| public VariableV1[] global_variables(string scope = null) | public VariableV1[] global_variables(string scope = null) | ||||
| { | { | ||||
| return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) as List<VariableV1>) | |||||
| return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) as List<VariableV1>) | |||||
| .ToArray(); | .ToArray(); | ||||
| } | } | ||||
| @@ -32,6 +33,14 @@ namespace Tensorflow | |||||
| return variables.variables_initializer(g.ToArray()); | return variables.variables_initializer(g.ToArray()); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Returns all variables created with `trainable=True`. | |||||
| /// </summary> | |||||
| /// <param name="scope"></param> | |||||
| /// <returns></returns> | |||||
| public VariableV1[] trainable_variables(string scope = null) | |||||
| => (variables.trainable_variables() as List<VariableV1>).ToArray(); | |||||
| public RefVariable get_variable(string name, | public RefVariable get_variable(string name, | ||||
| TensorShape shape = null, | TensorShape shape = null, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| @@ -0,0 +1,4 @@ | |||||
| using System.Runtime.CompilerServices; | |||||
| #if DEBUG | |||||
| [assembly: InternalsVisibleTo("TensorFlowNET.UnitTest, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")] | |||||
| #endif | |||||
| @@ -178,13 +178,18 @@ namespace Tensorflow | |||||
| public static IEnumerable<(TKey, TValue)> enumerate<TKey, TValue>(KeyValuePair<TKey, TValue>[] values) | public static IEnumerable<(TKey, TValue)> enumerate<TKey, TValue>(KeyValuePair<TKey, TValue>[] values) | ||||
| { | { | ||||
| foreach (var item in values) | |||||
| var len = values.Length; | |||||
| for (var i = 0; i < len; i++) | |||||
| { | |||||
| var item = values[i]; | |||||
| yield return (item.Key, item.Value); | yield return (item.Key, item.Value); | ||||
| } | |||||
| } | } | ||||
| public static IEnumerable<(int, T)> enumerate<T>(IList<T> values) | public static IEnumerable<(int, T)> enumerate<T>(IList<T> values) | ||||
| { | { | ||||
| for (int i = 0; i < values.Count; i++) | |||||
| var len = values.Count; | |||||
| for (int i = 0; i < len; i++) | |||||
| yield return (i, values[i]); | yield return (i, values[i]); | ||||
| } | } | ||||
| @@ -308,15 +313,14 @@ namespace Tensorflow | |||||
| public static IEnumerable TupleToEnumerable(object tuple) | public static IEnumerable TupleToEnumerable(object tuple) | ||||
| { | { | ||||
| Type t = tuple.GetType(); | Type t = tuple.GetType(); | ||||
| if(t.IsGenericType && (t.FullName.StartsWith("System.Tuple") || t.FullName.StartsWith("System.ValueTuple"))) | |||||
| if (t.IsGenericType && (t.FullName.StartsWith("System.Tuple") || t.FullName.StartsWith("System.ValueTuple"))) | |||||
| { | { | ||||
| var flds = t.GetFields(); | var flds = t.GetFields(); | ||||
| for(int i = 0; i < flds.Length;i++) | |||||
| for (int i = 0; i < flds.Length; i++) | |||||
| { | { | ||||
| yield return flds[i].GetValue(tuple); | yield return flds[i].GetValue(tuple); | ||||
| } | } | ||||
| } | |||||
| else | |||||
| } else | |||||
| { | { | ||||
| throw new System.Exception("Expected Tuple."); | throw new System.Exception("Expected Tuple."); | ||||
| } | } | ||||
| @@ -329,12 +333,9 @@ namespace Tensorflow | |||||
| public static bool isinstance(object Item1, object tuple) | public static bool isinstance(object Item1, object tuple) | ||||
| { | { | ||||
| var tup = TupleToEnumerable(tuple); | |||||
| foreach(var t in tup) | |||||
| { | |||||
| if(isinstance(Item1, (Type)t)) | |||||
| foreach (var t in TupleToEnumerable(tuple)) | |||||
| if (isinstance(Item1, (Type) t)) | |||||
| return true; | return true; | ||||
| } | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| @@ -15,58 +15,116 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | using System; | ||||
| using System.Runtime.CompilerServices; | |||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using NumSharp.Backends.Unmanaged; | |||||
| using static Tensorflow.c_api; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Represents a TF_Buffer that can be passed to Tensorflow. | |||||
| /// </summary> | |||||
| public class Buffer : DisposableObject | public class Buffer : DisposableObject | ||||
| { | { | ||||
| private TF_Buffer buffer => Marshal.PtrToStructure<TF_Buffer>(_handle); | |||||
| private unsafe TF_Buffer buffer | |||||
| { | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| get => *bufferptr; | |||||
| } | |||||
| private unsafe TF_Buffer* bufferptr | |||||
| { | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| get => (TF_Buffer*) _handle; | |||||
| } | |||||
| public byte[] Data | |||||
| /// <summary> | |||||
| /// The memory block representing this buffer. | |||||
| /// </summary> | |||||
| /// <remarks>The deallocator is set to null.</remarks> | |||||
| public UnmanagedMemoryBlock<byte> MemoryBlock | |||||
| { | { | ||||
| get | |||||
| get | |||||
| { | { | ||||
| var data = new byte[buffer.length]; | |||||
| if (data.Length > 0) | |||||
| Marshal.Copy(buffer.data, data, 0, data.Length); | |||||
| return data; | |||||
| unsafe | |||||
| { | |||||
| EnsureNotDisposed(); | |||||
| var buff = (TF_Buffer*) _handle; | |||||
| return new UnmanagedMemoryBlock<byte>((byte*) buff->data.ToPointer(), (long) buff->length); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| public int Length => (int)buffer.length; | |||||
| public Buffer() | |||||
| /// <summary> | |||||
| /// The bytes length of this buffer. | |||||
| /// </summary> | |||||
| public ulong Length | |||||
| { | { | ||||
| _handle = c_api.TF_NewBuffer(); | |||||
| get | |||||
| { | |||||
| EnsureNotDisposed(); | |||||
| return buffer.length; | |||||
| } | |||||
| } | } | ||||
| public Buffer(IntPtr handle) | |||||
| public Buffer() => _handle = TF_NewBuffer(); | |||||
| internal Buffer(IntPtr handle) | |||||
| { | { | ||||
| if (handle == IntPtr.Zero) | |||||
| throw new ArgumentException("Handle (IntPtr) can't be zero.", nameof(handle)); | |||||
| _handle = handle; | _handle = handle; | ||||
| } | } | ||||
| public Buffer(byte[] data) | |||||
| { | |||||
| var dst = Marshal.AllocHGlobal(data.Length); | |||||
| Marshal.Copy(data, 0, dst, data.Length); | |||||
| public Buffer(byte[] data) : this(_toBuffer(data)) | |||||
| { } | |||||
| _handle = c_api.TF_NewBufferFromString(dst, (ulong)data.Length); | |||||
| private static IntPtr _toBuffer(byte[] data) | |||||
| { | |||||
| if (data == null) | |||||
| throw new ArgumentNullException(nameof(data)); | |||||
| Marshal.FreeHGlobal(dst); | |||||
| unsafe | |||||
| { | |||||
| fixed (byte* src = data) | |||||
| return TF_NewBufferFromString(new IntPtr(src), (ulong) data.LongLength); | |||||
| } | |||||
| } | } | ||||
| public static implicit operator IntPtr(Buffer buffer) | public static implicit operator IntPtr(Buffer buffer) | ||||
| { | { | ||||
| buffer.EnsureNotDisposed(); | |||||
| return buffer._handle; | return buffer._handle; | ||||
| } | } | ||||
| public static implicit operator byte[](Buffer buffer) | |||||
| public static explicit operator byte[](Buffer buffer) => buffer.ToArray(); //has to be explicit, developer will assume it doesn't cost. | |||||
| /// <summary> | |||||
| /// Copies this buffer's contents onto a <see cref="byte"/> array. | |||||
| /// </summary> | |||||
| public byte[] ToArray() | |||||
| { | { | ||||
| return buffer.Data; | |||||
| EnsureNotDisposed(); | |||||
| unsafe | |||||
| { | |||||
| var len = buffer.length; | |||||
| if (len == 0) | |||||
| return Array.Empty<byte>(); | |||||
| byte[] data = new byte[len]; | |||||
| fixed (byte* dst = data) | |||||
| System.Buffer.MemoryCopy((void*) bufferptr->data, dst, len, len); | |||||
| return data; | |||||
| } | |||||
| } | } | ||||
| protected override void DisposeUnManagedState(IntPtr handle) | |||||
| => c_api.TF_DeleteBuffer(handle); | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| { | |||||
| TF_DeleteBuffer(handle); | |||||
| } | |||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -16,6 +16,8 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics.CodeAnalysis; | |||||
| using System.Runtime.CompilerServices; | |||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -26,52 +28,71 @@ namespace Tensorflow | |||||
| public abstract class DisposableObject : IDisposable | public abstract class DisposableObject : IDisposable | ||||
| { | { | ||||
| protected IntPtr _handle; | protected IntPtr _handle; | ||||
| protected bool _disposed; | |||||
| protected DisposableObject() { } | |||||
| [SuppressMessage("ReSharper", "UnusedMember.Global")] | |||||
| protected DisposableObject() | |||||
| { } | |||||
| public DisposableObject(IntPtr handle) | |||||
| { | |||||
| _handle = handle; | |||||
| } | |||||
| protected DisposableObject(IntPtr handle) | |||||
| => _handle = handle; | |||||
| protected virtual void DisposeManagedState() | |||||
| [SuppressMessage("ReSharper", "InvertIf")] | |||||
| private void internal_dispose(bool disposing) | |||||
| { | { | ||||
| } | |||||
| if (_disposed) | |||||
| return; | |||||
| protected abstract void DisposeUnManagedState(IntPtr handle); | |||||
| _disposed = true; | |||||
| protected virtual void Dispose(bool disposing) | |||||
| { | |||||
| //first handle managed, they might use the unmanaged resources. | |||||
| if (disposing) | if (disposing) | ||||
| { | |||||
| // free unmanaged resources (unmanaged objects) and override a finalizer below. | |||||
| if (_handle != IntPtr.Zero) | |||||
| { | |||||
| // dispose managed state (managed objects). | |||||
| DisposeManagedState(); | |||||
| // set large fields to null. | |||||
| DisposeUnManagedState(_handle); | |||||
| // dispose managed state (managed objects). | |||||
| DisposeManagedResources(); | |||||
| _handle = IntPtr.Zero; | |||||
| } | |||||
| //free unmanaged memory | |||||
| if (_handle != IntPtr.Zero) | |||||
| { | |||||
| DisposeUnmanagedResources(_handle); | |||||
| _handle = IntPtr.Zero; | |||||
| } | } | ||||
| } | } | ||||
| // override a finalizer only if Dispose(bool disposing) above has code to free unmanaged resources. | |||||
| /// <summary> | |||||
| /// Dispose any managed resources. | |||||
| /// </summary> | |||||
| /// <remarks>Equivalent to what you would perform inside <see cref="Dispose()"/></remarks> | |||||
| protected virtual void DisposeManagedResources() | |||||
| { } | |||||
| /// <summary> | |||||
| /// Dispose any unmanaged resources related to given <paramref name="handle"/>. | |||||
| /// </summary> | |||||
| protected abstract void DisposeUnmanagedResources(IntPtr handle); | |||||
| ~DisposableObject() | ~DisposableObject() | ||||
| { | { | ||||
| // Do not change this code. Put cleanup code in Dispose(bool disposing) above. | |||||
| Dispose(false); | |||||
| internal_dispose(false); | |||||
| } | } | ||||
| // This code added to correctly implement the disposable pattern. | |||||
| public void Dispose() | public void Dispose() | ||||
| { | { | ||||
| // Do not change this code. Put cleanup code in Dispose(bool disposing) above. | |||||
| Dispose(true); | |||||
| // uncomment the following line if the finalizer is overridden above. | |||||
| GC.SuppressFinalize(this); | |||||
| lock(this) | |||||
| { | |||||
| internal_dispose(true); | |||||
| GC.SuppressFinalize(this); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// If <see cref="_handle"/> is <see cref="IntPtr.Zero"/> then throws <see cref="ObjectDisposedException"/> | |||||
| /// </summary> | |||||
| /// <exception cref="ObjectDisposedException">When <see cref="_handle"/> is <see cref="IntPtr.Zero"/></exception> | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| protected void EnsureNotDisposed() | |||||
| { | |||||
| if (_disposed) | |||||
| throw new ObjectDisposedException($"Unable to access disposed object, Type: {GetType().Name}"); | |||||
| } | } | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -2,12 +2,10 @@ | |||||
| namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
| { | { | ||||
| public class Context : IDisposable | |||||
| public class Context : DisposableObject | |||||
| { | { | ||||
| private IntPtr _handle; | |||||
| public static int GRAPH_MODE = 0; | |||||
| public static int EAGER_MODE = 1; | |||||
| public const int GRAPH_MODE = 0; | |||||
| public const int EAGER_MODE = 1; | |||||
| public int default_execution_mode; | public int default_execution_mode; | ||||
| @@ -17,19 +15,16 @@ namespace Tensorflow.Eager | |||||
| status.Check(true); | status.Check(true); | ||||
| } | } | ||||
| public void Dispose() | |||||
| { | |||||
| c_api.TFE_DeleteContext(_handle); | |||||
| } | |||||
| /// <summary> | |||||
| /// Dispose any unmanaged resources related to given <paramref name="handle"/>. | |||||
| /// </summary> | |||||
| protected sealed override void DisposeUnmanagedResources(IntPtr handle) | |||||
| => c_api.TFE_DeleteContext(_handle); | |||||
| public bool executing_eagerly() | |||||
| { | |||||
| return false; | |||||
| } | |||||
| public static implicit operator IntPtr(Context ctx) | |||||
| { | |||||
| return ctx._handle; | |||||
| } | |||||
| public bool executing_eagerly() => false; | |||||
| public static implicit operator IntPtr(Context ctx) | |||||
| => ctx._handle; | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,24 +1,22 @@ | |||||
| using System; | using System; | ||||
| using System.IO; | |||||
| namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
| { | { | ||||
| public class ContextOptions : IDisposable | |||||
| public class ContextOptions : DisposableObject | |||||
| { | { | ||||
| private IntPtr _handle; | |||||
| public ContextOptions() : base(c_api.TFE_NewContextOptions()) | |||||
| { } | |||||
| public ContextOptions() | |||||
| { | |||||
| _handle = c_api.TFE_NewContextOptions(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Dispose any unmanaged resources related to given <paramref name="handle"/>. | |||||
| /// </summary> | |||||
| protected sealed override void DisposeUnmanagedResources(IntPtr handle) | |||||
| => c_api.TFE_DeleteContextOptions(_handle); | |||||
| public void Dispose() | |||||
| { | |||||
| c_api.TFE_DeleteContextOptions(_handle); | |||||
| } | |||||
| public static implicit operator IntPtr(ContextOptions opts) | |||||
| { | |||||
| return opts._handle; | |||||
| } | |||||
| public static implicit operator IntPtr(ContextOptions opts) | |||||
| => opts._handle; | |||||
| } | } | ||||
| } | } | ||||
| @@ -2,7 +2,7 @@ | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class KeyError : Exception | |||||
| public class KeyError : TensorflowException | |||||
| { | { | ||||
| public KeyError() : base() | public KeyError() : base() | ||||
| { | { | ||||
| @@ -2,7 +2,7 @@ | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class RuntimeError : Exception | |||||
| public class RuntimeError : TensorflowException | |||||
| { | { | ||||
| public RuntimeError() : base() | public RuntimeError() : base() | ||||
| { | { | ||||
| @@ -0,0 +1,36 @@ | |||||
| using System; | |||||
| using System.Runtime.Serialization; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// Serves as a base class to all exceptions of Tensorflow.NET. | |||||
| /// </summary> | |||||
| [Serializable] | |||||
| public class TensorflowException : Exception | |||||
| { | |||||
| /// <summary>Initializes a new instance of the <see cref="T:System.Exception"></see> class.</summary> | |||||
| public TensorflowException() | |||||
| { } | |||||
| /// <summary>Initializes a new instance of the <see cref="T:System.Exception"></see> class with serialized data.</summary> | |||||
| /// <param name="info">The <see cref="T:System.Runtime.Serialization.SerializationInfo"></see> that holds the serialized object data about the exception being thrown.</param> | |||||
| /// <param name="context">The <see cref="T:System.Runtime.Serialization.StreamingContext"></see> that contains contextual information about the source or destination.</param> | |||||
| /// <exception cref="T:System.ArgumentNullException">The <paramref name="info">info</paramref> parameter is null.</exception> | |||||
| /// <exception cref="T:System.Runtime.Serialization.SerializationException">The class name is null or <see cref="P:System.Exception.HResult"></see> is zero (0).</exception> | |||||
| protected TensorflowException(SerializationInfo info, StreamingContext context) : base(info, context) | |||||
| { } | |||||
| /// <summary>Initializes a new instance of the <see cref="T:System.Exception"></see> class with a specified error message.</summary> | |||||
| /// <param name="message">The message that describes the error.</param> | |||||
| public TensorflowException(string message) : base(message) | |||||
| { } | |||||
| /// <summary>Initializes a new instance of the <see cref="T:System.Exception"></see> class with a specified error message and a reference to the inner exception that is the cause of this exception.</summary> | |||||
| /// <param name="message">The error message that explains the reason for the exception.</param> | |||||
| /// <param name="innerException">The exception that is the cause of the current exception, or a null reference (Nothing in Visual Basic) if no inner exception is specified.</param> | |||||
| public TensorflowException(string message, Exception innerException) : base(message, innerException) | |||||
| { } | |||||
| } | |||||
| } | |||||
| @@ -2,7 +2,7 @@ | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class TypeError : Exception | |||||
| public class TypeError : TensorflowException | |||||
| { | { | ||||
| public TypeError() : base() | public TypeError() : base() | ||||
| { | { | ||||
| @@ -2,7 +2,7 @@ | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class ValueError : Exception | |||||
| public class ValueError : TensorflowException | |||||
| { | { | ||||
| public ValueError() : base() | public ValueError() : base() | ||||
| { | { | ||||
| @@ -6,10 +6,5 @@ | |||||
| { | { | ||||
| } | } | ||||
| ~ScopedTFImportGraphDefOptions() | |||||
| { | |||||
| base.Dispose(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -13,10 +13,5 @@ namespace Tensorflow.Framework.Models | |||||
| { | { | ||||
| } | } | ||||
| ~ScopedTFImportGraphDefResults() | |||||
| { | |||||
| base.Dispose(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -5,10 +5,5 @@ | |||||
| public ScopedTFStatus() : base() | public ScopedTFStatus() : base() | ||||
| { | { | ||||
| } | } | ||||
| ~ScopedTFStatus() | |||||
| { | |||||
| base.Dispose(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -95,7 +95,7 @@ namespace Tensorflow | |||||
| break; | break; | ||||
| case KindOneofCase.BytesList: | case KindOneofCase.BytesList: | ||||
| //var proto_type = ops.get_collection_proto_type(key) | //var proto_type = ops.get_collection_proto_type(key) | ||||
| if (ops.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key)) | |||||
| if (tf.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key)) | |||||
| { | { | ||||
| foreach (var value in col.Value.BytesList.Value) | foreach (var value in col.Value.BytesList.Value) | ||||
| { | { | ||||
| @@ -146,7 +146,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| var variables = graph.get_collection<VariableV1>(ops.GraphKeys.GLOBAL_VARIABLES, | |||||
| var variables = graph.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, | |||||
| scope: scope_to_prepend_to_names); | scope: scope_to_prepend_to_names); | ||||
| var var_list = new Dictionary<string, VariableV1>(); | var var_list = new Dictionary<string, VariableV1>(); | ||||
| variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); | variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); | ||||
| @@ -180,7 +180,7 @@ namespace Tensorflow | |||||
| var graph = ops.get_default_graph(); | var graph = ops.get_default_graph(); | ||||
| var var_list = new Dictionary<string, RefVariable>(); | var var_list = new Dictionary<string, RefVariable>(); | ||||
| var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) as List<RefVariable>; | |||||
| var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) as List<RefVariable>; | |||||
| if (variables != null) | if (variables != null) | ||||
| { | { | ||||
| @@ -15,6 +15,8 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | |||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -27,12 +29,12 @@ namespace Tensorflow | |||||
| if(_registered_ops == null) | if(_registered_ops == null) | ||||
| { | { | ||||
| _registered_ops = new Dictionary<string, OpDef>(); | _registered_ops = new Dictionary<string, OpDef>(); | ||||
| var handle = c_api.TF_GetAllOpList(); | |||||
| var buffer = new Buffer(handle); | |||||
| var op_list = OpList.Parser.ParseFrom(buffer); | |||||
| foreach (var op_def in op_list.Op) | |||||
| _registered_ops[op_def.Name] = op_def; | |||||
| using (var buffer = new Buffer(c_api.TF_GetAllOpList())) | |||||
| { | |||||
| var op_list = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| foreach (var op_def in op_list.Op) | |||||
| _registered_ops[op_def.Name] = op_def; | |||||
| } | |||||
| } | } | ||||
| return _registered_ops; | return _registered_ops; | ||||
| @@ -14,49 +14,62 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class DefaultGraphStack | |||||
| /// <summary> | |||||
| /// Serves as a stack for determining current default graph. | |||||
| /// </summary> | |||||
| public class DefaultGraphStack | |||||
| { | { | ||||
| List<StackModel> stack = new List<StackModel>(); | |||||
| private readonly List<StackModel> _stack = new List<StackModel>(); | |||||
| public void set_controller(Graph @default) | public void set_controller(Graph @default) | ||||
| { | { | ||||
| if (!stack.Exists(x => x.Graph == @default)) | |||||
| stack.Add(new StackModel { Graph = @default, IsDefault = true }); | |||||
| if (!_stack.Exists(x => x.Graph == @default)) | |||||
| _stack.Add(new StackModel {Graph = @default, IsDefault = true}); | |||||
| foreach (var s in stack) | |||||
| foreach (var s in _stack) | |||||
| s.IsDefault = s.Graph == @default; | s.IsDefault = s.Graph == @default; | ||||
| } | } | ||||
| public Graph get_controller() | public Graph get_controller() | ||||
| { | { | ||||
| if (stack.Count(x => x.IsDefault) == 0) | |||||
| stack.Add(new StackModel { Graph = tf.Graph(), IsDefault = true }); | |||||
| if (_stack.Count(x => x.IsDefault) == 0) | |||||
| _stack.Add(new StackModel {Graph = tf.Graph(), IsDefault = true}); | |||||
| for (var i = _stack.Count - 1; i >= 0; i--) | |||||
| { | |||||
| var x = _stack[i]; | |||||
| if (x.IsDefault) | |||||
| return x.Graph; | |||||
| } | |||||
| return stack.Last(x => x.IsDefault).Graph; | |||||
| throw new TensorflowException("Unable to find a default graph"); | |||||
| } | } | ||||
| public bool remove(Graph g) | public bool remove(Graph g) | ||||
| { | { | ||||
| var sm = stack.FirstOrDefault(x => x.Graph == g); | |||||
| if (sm == null) return false; | |||||
| return stack.Remove(sm); | |||||
| if (_stack.Count == 0) | |||||
| return false; | |||||
| var sm = _stack.Find(model => model.Graph == g); | |||||
| return sm != null && _stack.Remove(sm); | |||||
| } | } | ||||
| public void reset() | public void reset() | ||||
| { | { | ||||
| stack.Clear(); | |||||
| _stack.Clear(); | |||||
| } | } | ||||
| } | |||||
| public class StackModel | |||||
| { | |||||
| public Graph Graph { get; set; } | |||||
| public bool IsDefault { get; set; } | |||||
| private class StackModel | |||||
| { | |||||
| public Graph Graph { get; set; } | |||||
| public bool IsDefault { get; set; } | |||||
| } | |||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -15,6 +15,7 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics.CodeAnalysis; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| @@ -66,8 +67,9 @@ namespace Tensorflow | |||||
| /// within the context should have control dependencies on | /// within the context should have control dependencies on | ||||
| /// `control_inputs`. | /// `control_inputs`. | ||||
| /// </summary> | /// </summary> | ||||
| [SuppressMessage("ReSharper", "CoVariantArrayConversion")] | |||||
| public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) | public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) | ||||
| => control_dependencies(control_inputs == null ? null : control_inputs.OfType<object>().ToArray()); | |||||
| => control_dependencies((object[])control_inputs); | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns a context manager that specifies control dependencies. | /// Returns a context manager that specifies control dependencies. | ||||
| @@ -14,6 +14,9 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.IO; | |||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public partial class Graph | public partial class Graph | ||||
| @@ -23,21 +26,19 @@ namespace Tensorflow | |||||
| var buffer = new Buffer(); | var buffer = new Buffer(); | ||||
| c_api.TF_GraphToGraphDef(_handle, buffer, s); | c_api.TF_GraphToGraphDef(_handle, buffer, s); | ||||
| s.Check(true); | s.Check(true); | ||||
| // var def = GraphDef.Parser.ParseFrom(buffer); | |||||
| // buffer.Dispose(); | |||||
| return buffer; | return buffer; | ||||
| } | } | ||||
| private GraphDef _as_graph_def(bool add_shapes = false) | private GraphDef _as_graph_def(bool add_shapes = false) | ||||
| { | { | ||||
| var status = new Status(); | |||||
| var buffer = ToGraphDef(status); | |||||
| status.Check(true); | |||||
| status.Dispose(); | |||||
| var def = GraphDef.Parser.ParseFrom(buffer); | |||||
| buffer.Dispose(); | |||||
| GraphDef def; | |||||
| using (var status = new Status()) | |||||
| using (var buffer = ToGraphDef(status)) | |||||
| { | |||||
| status.Check(true); | |||||
| def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| } | |||||
| // Strip the experimental library field iff it's empty. | // Strip the experimental library field iff it's empty. | ||||
| // if(def.Library.Function.Count == 0) | // if(def.Library.Function.Count == 0) | ||||
| @@ -45,7 +46,7 @@ namespace Tensorflow | |||||
| return def; | return def; | ||||
| } | } | ||||
| public GraphDef as_graph_def(bool add_shapes = false) | |||||
| public GraphDef as_graph_def(bool add_shapes = false) | |||||
| => _as_graph_def(add_shapes); | => _as_graph_def(add_shapes); | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -30,11 +30,10 @@ namespace Tensorflow | |||||
| var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); | var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); | ||||
| c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); | c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); | ||||
| for (int i = 0; i < num_return_outputs; i++) | |||||
| { | |||||
| var handle = return_output_handle + i * size; | |||||
| return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle); | |||||
| } | |||||
| var tf_output_ptr = (TF_Output*) return_output_handle; | |||||
| for (int i = 0; i < num_return_outputs; i++) | |||||
| return_outputs[i] = *(tf_output_ptr + i); | |||||
| Marshal.FreeHGlobal(return_output_handle); | Marshal.FreeHGlobal(return_output_handle); | ||||
| @@ -18,6 +18,7 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using Tensorflow.Util; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -30,7 +31,7 @@ namespace Tensorflow | |||||
| using (var status = new Status()) | using (var status = new Status()) | ||||
| { | { | ||||
| c_api.TF_GraphGetOpDef(_handle, type, buffer, status); | c_api.TF_GraphGetOpDef(_handle, type, buffer, status); | ||||
| return OpDef.Parser.ParseFrom(buffer.Data); | |||||
| return OpDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| } | } | ||||
| } | } | ||||
| @@ -39,16 +40,20 @@ namespace Tensorflow | |||||
| return c_api.TF_NewOperation(_handle, opType, opName); | return c_api.TF_NewOperation(_handle, opType, opName); | ||||
| } | } | ||||
| public unsafe Operation[] ReturnOperations(IntPtr results) | |||||
| public Operation[] ReturnOperations(IntPtr results) | |||||
| { | { | ||||
| TF_Operation return_oper_handle = new TF_Operation(); | TF_Operation return_oper_handle = new TF_Operation(); | ||||
| int num_return_opers = 0; | int num_return_opers = 0; | ||||
| c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); | c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); | ||||
| Operation[] return_opers = new Operation[num_return_opers]; | Operation[] return_opers = new Operation[num_return_opers]; | ||||
| var tf_op_size = Marshal.SizeOf<TF_Operation>(); | |||||
| for (int i = 0; i < num_return_opers; i++) | for (int i = 0; i < num_return_opers; i++) | ||||
| { | { | ||||
| var handle = return_oper_handle.node + Marshal.SizeOf<TF_Operation>() * i; | |||||
| return_opers[i] = new Operation(*(IntPtr*)handle); | |||||
| unsafe | |||||
| { | |||||
| var handle = return_oper_handle.node + tf_op_size * i; | |||||
| return_opers[i] = new Operation(*(IntPtr*)handle); | |||||
| } | |||||
| } | } | ||||
| return return_opers; | return return_opers; | ||||
| @@ -67,7 +72,7 @@ namespace Tensorflow | |||||
| public ITensorOrOperation[] get_operations() | public ITensorOrOperation[] get_operations() | ||||
| { | { | ||||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | |||||
| return _nodes_by_name.Values.ToArray(); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -81,7 +86,7 @@ namespace Tensorflow | |||||
| public ITensorOrOperation _get_operation_by_name_unsafe(string name) | public ITensorOrOperation _get_operation_by_name_unsafe(string name) | ||||
| { | { | ||||
| return _nodes_by_name.ContainsKey(name) ? _nodes_by_name[name] : null; | |||||
| return _nodes_by_name.TryGetValue(name, out var val) ? val : null; | |||||
| } | } | ||||
| public ITensorOrOperation _get_operation_by_tf_operation(IntPtr tf_oper) | public ITensorOrOperation _get_operation_by_tf_operation(IntPtr tf_oper) | ||||
| @@ -23,57 +23,58 @@ using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| /* | |||||
| A TensorFlow computation, represented as a dataflow graph. | |||||
| A `Graph` contains a set of | |||||
| `tf.Operation` objects, | |||||
| which represent units of computation; and | |||||
| `tf.Tensor` objects, which represent | |||||
| the units of data that flow between operations. | |||||
| A default `Graph` is always registered, and accessible by calling | |||||
| `tf.get_default_graph`. | |||||
| To add an operation to the default graph, simply call one of the functions | |||||
| that defines a new `Operation`: | |||||
| ```python | |||||
| c = tf.constant(4.0) | |||||
| assert c.graph is tf.get_default_graph() | |||||
| ``` | |||||
| Another typical usage involves the | |||||
| `tf.Graph.as_default` | |||||
| context manager, which overrides the current default graph for the | |||||
| lifetime of the context: | |||||
| ```python | |||||
| g = tf.Graph() | |||||
| with g.as_default(): | |||||
| # Define operations and tensors in `g`. | |||||
| c = tf.constant(30.0) | |||||
| assert c.graph is g | |||||
| ``` | |||||
| Important note: This class *is not* thread-safe for graph construction. All | |||||
| operations should be created from a single thread, or external | |||||
| synchronization must be provided. Unless otherwise specified, all methods | |||||
| are not thread-safe. | |||||
| A `Graph` instance supports an arbitrary number of "collections" | |||||
| that are identified by name. For convenience when building a large | |||||
| graph, collections can store groups of related objects: for | |||||
| example, the `tf.Variable` uses a collection (named | |||||
| `tf.GraphKeys.GLOBAL_VARIABLES`) for | |||||
| all variables that are created during the construction of a graph. The caller | |||||
| may define additional collections by specifying a new name. | |||||
| */ | |||||
| /// <summary> | /// <summary> | ||||
| /// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations. | |||||
| /// This leads to a low-level programming model in which you first define the dataflow graph, | |||||
| /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | |||||
| /// https://www.tensorflow.org/guide/graphs | |||||
| /// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations. | |||||
| /// This leads to a low-level programming model in which you first define the dataflow graph, | |||||
| /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | |||||
| /// </summary> | /// </summary> | ||||
| /* | |||||
| A TensorFlow computation, represented as a dataflow graph. | |||||
| A `Graph` contains a set of | |||||
| `tf.Operation` objects, | |||||
| which represent units of computation; and | |||||
| `tf.Tensor` objects, which represent | |||||
| the units of data that flow between operations. | |||||
| A default `Graph` is always registered, and accessible by calling | |||||
| `tf.get_default_graph`. | |||||
| To add an operation to the default graph, simply call one of the functions | |||||
| that defines a new `Operation`: | |||||
| ```python | |||||
| c = tf.constant(4.0) | |||||
| assert c.graph is tf.get_default_graph() | |||||
| ``` | |||||
| Another typical usage involves the | |||||
| `tf.Graph.as_default` | |||||
| context manager, which overrides the current default graph for the | |||||
| lifetime of the context: | |||||
| ```python | |||||
| g = tf.Graph() | |||||
| with g.as_default(): | |||||
| # Define operations and tensors in `g`. | |||||
| c = tf.constant(30.0) | |||||
| assert c.graph is g | |||||
| ``` | |||||
| Important note: This class *is not* thread-safe for graph construction. All | |||||
| operations should be created from a single thread, or external | |||||
| synchronization must be provided. Unless otherwise specified, all methods | |||||
| are not thread-safe. | |||||
| A `Graph` instance supports an arbitrary number of "collections" | |||||
| that are identified by name. For convenience when building a large | |||||
| graph, collections can store groups of related objects: for | |||||
| example, the `tf.Variable` uses a collection (named | |||||
| `tf.GraphKeys.GLOBAL_VARIABLES`) for | |||||
| all variables that are created during the construction of a graph. The caller | |||||
| may define additional collections by specifying a new name. | |||||
| */ | |||||
| /// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks> | |||||
| public partial class Graph : DisposableObject, IEnumerable<Operation> | public partial class Graph : DisposableObject, IEnumerable<Operation> | ||||
| { | { | ||||
| private Dictionary<int, ITensorOrOperation> _nodes_by_id; | private Dictionary<int, ITensorOrOperation> _nodes_by_id; | ||||
| @@ -368,7 +369,7 @@ namespace Tensorflow | |||||
| var name_key = name.ToLower(); | var name_key = name.ToLower(); | ||||
| int i = 0; | int i = 0; | ||||
| if (_names_in_use.ContainsKey(name_key)) | if (_names_in_use.ContainsKey(name_key)) | ||||
| i = _names_in_use[name_key]; | |||||
| i = _names_in_use[name_key]; | |||||
| // Increment the number for "name_key". | // Increment the number for "name_key". | ||||
| if (mark_as_used) | if (mark_as_used) | ||||
| _names_in_use[name_key] = i + 1; | _names_in_use[name_key] = i + 1; | ||||
| @@ -398,13 +399,13 @@ namespace Tensorflow | |||||
| int num_return_outputs = 0; | int num_return_outputs = 0; | ||||
| c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle); | c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle); | ||||
| TF_Output[] return_outputs = new TF_Output[num_return_outputs]; | TF_Output[] return_outputs = new TF_Output[num_return_outputs]; | ||||
| for (int i = 0; i < num_return_outputs; i++) | |||||
| unsafe | |||||
| { | { | ||||
| var handle = return_output_handle + (Marshal.SizeOf<TF_Output>() * i); | |||||
| return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle); | |||||
| var tf_output_ptr = (TF_Output*) return_output_handle; | |||||
| for (int i = 0; i < num_return_outputs; i++) | |||||
| return_outputs[i] = *(tf_output_ptr + i); | |||||
| return return_outputs; | |||||
| } | } | ||||
| return return_outputs; | |||||
| } | } | ||||
| public string[] get_all_collection_keys() | public string[] get_all_collection_keys() | ||||
| @@ -439,12 +440,12 @@ namespace Tensorflow | |||||
| _unfetchable_ops.Add(op); | _unfetchable_ops.Add(op); | ||||
| } | } | ||||
| protected override void DisposeManagedState() | |||||
| protected override void DisposeManagedResources() | |||||
| { | { | ||||
| ops.default_graph_stack.remove(this); | ops.default_graph_stack.remove(this); | ||||
| } | } | ||||
| protected override void DisposeUnManagedState(IntPtr handle) | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| { | { | ||||
| c_api.TF_DeleteGraph(handle); | c_api.TF_DeleteGraph(handle); | ||||
| } | } | ||||
| @@ -496,11 +497,9 @@ namespace Tensorflow | |||||
| IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator() | IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator() | ||||
| => GetEnumerable().GetEnumerator(); | => GetEnumerable().GetEnumerator(); | ||||
| IEnumerator IEnumerable.GetEnumerator() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| IEnumerator IEnumerable.GetEnumerator() | |||||
| => throw new NotImplementedException(); | |||||
| public static implicit operator IntPtr(Graph graph) | public static implicit operator IntPtr(Graph graph) | ||||
| { | { | ||||
| return graph._handle; | return graph._handle; | ||||
| @@ -20,7 +20,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class ImportGraphDefOptions : DisposableObject | public class ImportGraphDefOptions : DisposableObject | ||||
| { | { | ||||
| public int NumReturnOutputs => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); | |||||
| public int NumReturnOutputs | |||||
| => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); | |||||
| public ImportGraphDefOptions() | public ImportGraphDefOptions() | ||||
| { | { | ||||
| @@ -37,7 +38,7 @@ namespace Tensorflow | |||||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | ||||
| } | } | ||||
| protected override void DisposeUnManagedState(IntPtr handle) | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| => c_api.TF_DeleteImportGraphDefOptions(handle); | => c_api.TF_DeleteImportGraphDefOptions(handle); | ||||
| public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle; | public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle; | ||||
| @@ -16,6 +16,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | using System.IO; | ||||
| using System.Linq; | |||||
| namespace Tensorflow.IO | namespace Tensorflow.IO | ||||
| { | { | ||||
| @@ -28,6 +29,9 @@ namespace Tensorflow.IO | |||||
| /// <param name="in_order">Traverse in order if True, post order if False.</param> | /// <param name="in_order">Traverse in order if True, post order if False.</param> | ||||
| public IEnumerable<(string, string[], string[])> Walk(string top, bool in_order = true) | public IEnumerable<(string, string[], string[])> Walk(string top, bool in_order = true) | ||||
| { | { | ||||
| if (!Directory.Exists(top)) | |||||
| return Enumerable.Empty<(string, string[], string[])>(); | |||||
| return walk_v2(top, in_order); | return walk_v2(top, in_order); | ||||
| } | } | ||||
| @@ -81,7 +81,7 @@ namespace Tensorflow.Layers | |||||
| // Update global default collections. | // Update global default collections. | ||||
| _add_elements_to_collection(_updates.ToArray(), new string[] { ops.GraphKeys.UPDATE_OPS }); | |||||
| _add_elements_to_collection(_updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS }); | |||||
| return outputs; | return outputs; | ||||
| } | } | ||||
| @@ -152,9 +152,9 @@ namespace Tensorflow.Operations | |||||
| public (T, Tensor) BuildCondBranch<T>(Func<T> fn) | public (T, Tensor) BuildCondBranch<T>(Func<T> fn) | ||||
| { | { | ||||
| // Add the subgraph defined by fn() to the graph. | // Add the subgraph defined by fn() to the graph. | ||||
| var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||||
| var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||||
| var original_result = fn(); | var original_result = fn(); | ||||
| var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||||
| var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||||
| //TODO: port this chunck of missing code: | //TODO: port this chunck of missing code: | ||||
| /* | /* | ||||
| @@ -191,9 +191,9 @@ namespace Tensorflow.Operations | |||||
| public (T[], Tensor[]) BuildCondBranch<T>(Func<T[]> fn) | public (T[], Tensor[]) BuildCondBranch<T>(Func<T[]> fn) | ||||
| { | { | ||||
| // Add the subgraph defined by fn() to the graph. | // Add the subgraph defined by fn() to the graph. | ||||
| var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||||
| var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||||
| var original_result = fn(); | var original_result = fn(); | ||||
| var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||||
| var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||||
| switch (original_result) | switch (original_result) | ||||
| { | { | ||||
| @@ -141,7 +141,7 @@ namespace Tensorflow.Operations | |||||
| data, frame_name, is_constant, parallel_iterations, name: name); | data, frame_name, is_constant, parallel_iterations, name: name); | ||||
| if (use_input_shape) | if (use_input_shape) | ||||
| result.SetShape(data.TensorShape); | |||||
| result.set_shape(data.TensorShape); | |||||
| return result; | return result; | ||||
| } | } | ||||
| @@ -195,7 +195,7 @@ namespace Tensorflow.Operations | |||||
| // their associated TensorArrays for calling the body. | // their associated TensorArrays for calling the body. | ||||
| var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); | var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); | ||||
| var body_result = body(packed_vars_for_body[0]); | var body_result = body(packed_vars_for_body[0]); | ||||
| var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||||
| var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||||
| // Store body_result to keep track of TensorArrays returned by body | // Store body_result to keep track of TensorArrays returned by body | ||||
| var original_body_result = new[] { body_result }; | var original_body_result = new[] { body_result }; | ||||
| @@ -0,0 +1,59 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Operations.Initializers | |||||
| { | |||||
| public class RandomNormal : IInitializer | |||||
| { | |||||
| private float mean; | |||||
| private float stddev; | |||||
| private int? seed; | |||||
| private TF_DataType dtype; | |||||
| public RandomNormal(float mean = 0.0f, | |||||
| float stddev = 1.0f, | |||||
| int? seed = null, | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
| { | |||||
| this.mean = mean; | |||||
| this.stddev = stddev; | |||||
| this.seed = seed; | |||||
| this.dtype = dtype; | |||||
| } | |||||
| public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
| { | |||||
| if (dtype == TF_DataType.DtInvalid) | |||||
| dtype = this.dtype; | |||||
| return random_ops.random_normal(shape, mean, stddev, dtype, seed: seed); | |||||
| } | |||||
| public object get_config() | |||||
| { | |||||
| return new | |||||
| { | |||||
| mean, | |||||
| stddev, | |||||
| seed, | |||||
| dtype | |||||
| }; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -2,7 +2,7 @@ | |||||
| { | { | ||||
| public class Util | public class Util | ||||
| { | { | ||||
| public static void add_loss(Tensor loss, string loss_collection = ops.GraphKeys.LOSSES) | |||||
| public static void add_loss(Tensor loss, string loss_collection = "losses") | |||||
| { | { | ||||
| if (!string.IsNullOrEmpty(loss_collection)) | if (!string.IsNullOrEmpty(loss_collection)) | ||||
| ops.add_to_collection(loss_collection, loss); | ops.add_to_collection(loss_collection, loss); | ||||
| @@ -22,7 +22,7 @@ namespace Tensorflow | |||||
| public class LossesImpl | public class LossesImpl | ||||
| { | { | ||||
| public Tensor compute_weighted_loss(Tensor losses, Tensor weights = null, string scope = null, | public Tensor compute_weighted_loss(Tensor losses, Tensor weights = null, string scope = null, | ||||
| string loss_collection = ops.GraphKeys.LOSSES, string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) | |||||
| string loss_collection = "losses", string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) | |||||
| { | { | ||||
| return tf_with(ops.name_scope(scope, default_name: "weighted_loss", (losses, weights)), delegate | return tf_with(ops.name_scope(scope, default_name: "weighted_loss", (losses, weights)), delegate | ||||
| { | { | ||||
| @@ -101,7 +101,7 @@ namespace Tensorflow | |||||
| Tensor logits, | Tensor logits, | ||||
| float weights = 1.0f, | float weights = 1.0f, | ||||
| string scope = null, | string scope = null, | ||||
| string loss_collection= ops.GraphKeys.LOSSES, | |||||
| string loss_collection= "losses", | |||||
| string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) | string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) | ||||
| { | { | ||||
| return tf_with(ops.name_scope(scope, | return tf_with(ops.name_scope(scope, | ||||
| @@ -181,6 +181,31 @@ namespace Tensorflow.Operations | |||||
| return _op.outputs; | return _op.outputs; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Local Response Normalization. | |||||
| /// </summary> | |||||
| /// <param name="input"></param> | |||||
| /// <param name="depth_radius"></param> | |||||
| /// <param name="bias"></param> | |||||
| /// <param name="alpha"></param> | |||||
| /// <param name="beta"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor local_response_normalization(Tensor input, int depth_radius = 5, int bias = 1, | |||||
| int alpha = 1, float beta = 0.5f, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("LRN", name: name, args: new | |||||
| { | |||||
| input, | |||||
| depth_radius, | |||||
| bias, | |||||
| alpha, | |||||
| beta | |||||
| }); | |||||
| return _op.output; | |||||
| } | |||||
| public static Tensor log_softmax(Tensor logits, string name = null) | public static Tensor log_softmax(Tensor logits, string name = null) | ||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("LogSoftmax", name: name, args: new | var _op = _op_def_lib._apply_op_helper("LogSoftmax", name: name, args: new | ||||
| @@ -189,6 +214,17 @@ namespace Tensorflow.Operations | |||||
| }); | }); | ||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | |||||
| public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("LeakyRelu", name: name, args: new | |||||
| { | |||||
| features, | |||||
| alpha | |||||
| }); | |||||
| return _op.output; | |||||
| } | } | ||||
| public static Tensor max_pool(Tensor input, | public static Tensor max_pool(Tensor input, | ||||
| @@ -233,7 +233,7 @@ namespace Tensorflow.Operations | |||||
| dims.AddRange(x_static_shape.dims.Skip(2)); | dims.AddRange(x_static_shape.dims.Skip(2)); | ||||
| var shape = new TensorShape(dims.ToArray()); | var shape = new TensorShape(dims.ToArray()); | ||||
| x_t.SetShape(shape); | |||||
| x_t.set_shape(shape); | |||||
| return x_t; | return x_t; | ||||
| } | } | ||||
| @@ -50,14 +50,12 @@ namespace Tensorflow | |||||
| public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) | public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) | ||||
| { | { | ||||
| int size = Marshal.SizeOf<TF_Input>(); | |||||
| var handle = Marshal.AllocHGlobal(size); | |||||
| var handle = Marshal.AllocHGlobal(Marshal.SizeOf<TF_Input>()); | |||||
| int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); | int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); | ||||
| var consumers = new TF_Input[num]; | var consumers = new TF_Input[num]; | ||||
| var inputptr = (TF_Input*) handle; | |||||
| for (int i = 0; i < num; i++) | for (int i = 0; i < num; i++) | ||||
| { | |||||
| consumers[i] = Marshal.PtrToStructure<TF_Input>(handle + i * size); | |||||
| } | |||||
| consumers[i] = *(inputptr + i); | |||||
| return consumers; | return consumers; | ||||
| } | } | ||||
| @@ -17,7 +17,9 @@ | |||||
| using Google.Protobuf.Collections; | using Google.Protobuf.Collections; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -226,9 +228,12 @@ namespace Tensorflow | |||||
| using (var status = new Status()) | using (var status = new Status()) | ||||
| using (var buf = new Buffer()) | using (var buf = new Buffer()) | ||||
| { | { | ||||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||||
| status.Check(true); | |||||
| x = AttrValue.Parser.ParseFrom(buf); | |||||
| unsafe | |||||
| { | |||||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||||
| status.Check(true); | |||||
| x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); | |||||
| } | |||||
| } | } | ||||
| string oneof_value = x.ValueCase.ToString(); | string oneof_value = x.ValueCase.ToString(); | ||||
| @@ -259,7 +264,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| c_api.TF_OperationToNodeDef(_handle, buffer, s); | c_api.TF_OperationToNodeDef(_handle, buffer, s); | ||||
| s.Check(); | s.Check(); | ||||
| return NodeDef.Parser.ParseFrom(buffer); | |||||
| return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| } | } | ||||
| } | } | ||||
| @@ -299,8 +304,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public TF_Output _tf_output(int output_idx) | public TF_Output _tf_output(int output_idx) | ||||
| { | { | ||||
| var tf_output = new TF_Output(op, output_idx); | |||||
| return tf_output; | |||||
| return new TF_Output(op, output_idx); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -308,8 +312,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public TF_Input _tf_input(int input_idx) | public TF_Input _tf_input(int input_idx) | ||||
| { | { | ||||
| var tf_input = new TF_Input(op, input_idx); | |||||
| return tf_input; | |||||
| return new TF_Input(op, input_idx); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -260,8 +260,7 @@ namespace Tensorflow | |||||
| return tf_with(ops.name_scope(name, "ones", new { dims }), scope => | return tf_with(ops.name_scope(name, "ones", new { dims }), scope => | ||||
| { | { | ||||
| name = scope; | name = scope; | ||||
| var shape = ops.convert_to_tensor(dims, dtype: TF_DataType.TF_INT32); | |||||
| var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name); | |||||
| var output = _constant_if_small(1, dims, dtype, name); | |||||
| return output; | return output; | ||||
| }); | }); | ||||
| } | } | ||||
| @@ -351,7 +350,7 @@ namespace Tensorflow | |||||
| var input_shape = tensor_util.to_shape(input_tensor.shape); | var input_shape = tensor_util.to_shape(input_tensor.shape); | ||||
| if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) | if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) | ||||
| { | { | ||||
| var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_datatype()); | |||||
| var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype()); | |||||
| return constant_op.constant(nd, name: name); | return constant_op.constant(nd, name: name); | ||||
| } | } | ||||
| } | } | ||||
| @@ -431,8 +431,8 @@ namespace Tensorflow | |||||
| merges = _convert_flows_to_tensorarrays(new Tensor[] { (Tensor)orig_res_t }, merges); | merges = _convert_flows_to_tensorarrays(new Tensor[] { (Tensor)orig_res_t }, merges); | ||||
| ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t); | |||||
| ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f); | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||||
| return merges[0]; | return merges[0]; | ||||
| }); | }); | ||||
| @@ -479,8 +479,8 @@ namespace Tensorflow | |||||
| merges = _convert_flows_to_tensorarrays(orig_res_t, merges); | merges = _convert_flows_to_tensorarrays(orig_res_t, merges); | ||||
| ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t); | |||||
| ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f); | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||||
| return merges; | return merges; | ||||
| }); | }); | ||||
| @@ -596,7 +596,7 @@ namespace Tensorflow | |||||
| swap_memory: swap_memory); | swap_memory: swap_memory); | ||||
| if (loop_context.outer_context == null) | if (loop_context.outer_context == null) | ||||
| ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context); | |||||
| ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context); | |||||
| var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, | var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, | ||||
| return_same_structure); | return_same_structure); | ||||
| @@ -23,7 +23,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | ||||
| public Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, string name= null) | |||||
| public static Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, string name= null) | |||||
| { | { | ||||
| if (dtype == image.dtype) | if (dtype == image.dtype) | ||||
| return array_ops.identity(image, name: name); | return array_ops.identity(image, name: name); | ||||
| @@ -57,7 +57,7 @@ namespace Tensorflow | |||||
| }); | }); | ||||
| } | } | ||||
| public Tensor decode_jpeg(Tensor contents, | |||||
| public static Tensor decode_jpeg(Tensor contents, | |||||
| int channels = 0, | int channels = 0, | ||||
| int ratio = 1, | int ratio = 1, | ||||
| bool fancy_upscaling = true, | bool fancy_upscaling = true, | ||||
| @@ -88,7 +88,70 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| public Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, string name = null) | |||||
| public static Tensor decode_gif(Tensor contents, | |||||
| string name = null) | |||||
| { | |||||
| // Add nodes to the TensorFlow graph. | |||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| throw new NotImplementedException("decode_gif"); | |||||
| } | |||||
| else | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("DecodeGif", name: name, args: new | |||||
| { | |||||
| contents | |||||
| }); | |||||
| return _op.output; | |||||
| } | |||||
| } | |||||
| public static Tensor decode_png(Tensor contents, | |||||
| int channels = 0, | |||||
| TF_DataType dtype = TF_DataType.TF_UINT8, | |||||
| string name = null) | |||||
| { | |||||
| // Add nodes to the TensorFlow graph. | |||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| throw new NotImplementedException("decode_png"); | |||||
| } | |||||
| else | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("DecodePng", name: name, args: new | |||||
| { | |||||
| contents, | |||||
| channels, | |||||
| dtype | |||||
| }); | |||||
| return _op.output; | |||||
| } | |||||
| } | |||||
| public static Tensor decode_bmp(Tensor contents, | |||||
| int channels = 0, | |||||
| string name = null) | |||||
| { | |||||
| // Add nodes to the TensorFlow graph. | |||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| throw new NotImplementedException("decode_bmp"); | |||||
| } | |||||
| else | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("DecodeBmp", name: name, args: new | |||||
| { | |||||
| contents, | |||||
| channels | |||||
| }); | |||||
| return _op.output; | |||||
| } | |||||
| } | |||||
| public static Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, string name = null) | |||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
| { | { | ||||
| @@ -141,7 +141,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); | ||||
| return _op.outputs[0]; | |||||
| return _op.output; | |||||
| } | } | ||||
| public static Tensor atan(Tensor x, string name = null) | public static Tensor atan(Tensor x, string name = null) | ||||
| @@ -40,7 +40,7 @@ namespace Tensorflow | |||||
| name: name, | name: name, | ||||
| args: new { shape, dtype, seed, seed2 }); | args: new { shape, dtype, seed, seed2 }); | ||||
| return _op.outputs[0]; | |||||
| return _op.output; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -0,0 +1,42 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class gen_string_ops | |||||
| { | |||||
| static readonly OpDefLibrary _op_def_lib; | |||||
| static gen_string_ops() { _op_def_lib = new OpDefLibrary(); } | |||||
| public static Tensor substr(Tensor input, int pos, int len, | |||||
| string name = null, string @uint = "BYTE") | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("Substr", name: name, args: new | |||||
| { | |||||
| input, | |||||
| pos, | |||||
| len, | |||||
| unit = @uint | |||||
| }); | |||||
| return _op.output; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,132 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class image_ops_impl | |||||
| { | |||||
| public static Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype = TF_DataType.TF_UINT8, | |||||
| string name = null, bool expand_animations = true) | |||||
| { | |||||
| Tensor substr = null; | |||||
| Func<ITensorOrOperation> _jpeg = () => | |||||
| { | |||||
| int jpeg_channels = channels; | |||||
| var good_channels = math_ops.not_equal(jpeg_channels, 4, name: "check_jpeg_channels"); | |||||
| string channels_msg = "Channels must be in (None, 0, 1, 3) when decoding JPEG 'images'"; | |||||
| var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg }); | |||||
| return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate | |||||
| { | |||||
| return convert_image_dtype(gen_image_ops.decode_jpeg(contents, channels), dtype); | |||||
| }); | |||||
| }; | |||||
| Func<ITensorOrOperation> _gif = () => | |||||
| { | |||||
| int gif_channels = channels; | |||||
| var good_channels = math_ops.logical_and( | |||||
| math_ops.not_equal(gif_channels, 1, name: "check_gif_channels"), | |||||
| math_ops.not_equal(gif_channels, 4, name: "check_gif_channels")); | |||||
| string channels_msg = "Channels must be in (None, 0, 3) when decoding GIF images"; | |||||
| var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg }); | |||||
| return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate | |||||
| { | |||||
| var result = convert_image_dtype(gen_image_ops.decode_gif(contents), dtype); | |||||
| if (!expand_animations) | |||||
| // result = array_ops.gather(result, 0); | |||||
| throw new NotImplementedException(""); | |||||
| return result; | |||||
| }); | |||||
| }; | |||||
| Func<ITensorOrOperation> _bmp = () => | |||||
| { | |||||
| int bmp_channels = channels; | |||||
| var signature = string_ops.substr(contents, 0, 2); | |||||
| var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp"); | |||||
| string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP"; | |||||
| var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg }); | |||||
| var good_channels = math_ops.not_equal(bmp_channels, 1, name: "check_channels"); | |||||
| string channels_msg = "Channels must be in (None, 0, 3) when decoding BMP images"; | |||||
| var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg }); | |||||
| return tf_with(ops.control_dependencies(new[] { assert_decode, assert_channels }), delegate | |||||
| { | |||||
| return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype); | |||||
| }); | |||||
| }; | |||||
| Func<ITensorOrOperation> _png = () => | |||||
| { | |||||
| return convert_image_dtype(gen_image_ops.decode_png( | |||||
| contents, | |||||
| channels, | |||||
| dtype: dtype), | |||||
| dtype); | |||||
| }; | |||||
| Func<ITensorOrOperation> check_gif = () => | |||||
| { | |||||
| var is_gif = math_ops.equal(substr, "\x47\x49\x46", name: "is_gif"); | |||||
| return control_flow_ops.cond(is_gif, _gif, _bmp, name: "cond_gif"); | |||||
| }; | |||||
| Func<ITensorOrOperation> check_png = () => | |||||
| { | |||||
| return control_flow_ops.cond(_is_png(contents), _png, check_gif, name: "cond_png"); | |||||
| }; | |||||
| return tf_with(ops.name_scope(name, "decode_image"), scope => | |||||
| { | |||||
| substr = string_ops.substr(contents, 0, 3); | |||||
| return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg"); | |||||
| }); | |||||
| } | |||||
| public static Tensor is_jpeg(Tensor contents, string name = null) | |||||
| { | |||||
| return tf_with(ops.name_scope(name, "is_jpeg"), scope => | |||||
| { | |||||
| var substr = string_ops.substr(contents, 0, 3); | |||||
| return math_ops.equal(substr, "\xff\xd8\xff", name: name); | |||||
| }); | |||||
| } | |||||
| public static Tensor _is_png(Tensor contents, string name = null) | |||||
| { | |||||
| return tf_with(ops.name_scope(name, "is_png"), scope => | |||||
| { | |||||
| var substr = string_ops.substr(contents, 0, 3); | |||||
| return math_ops.equal(substr, @"\211PN", name: name); | |||||
| }); | |||||
| } | |||||
| public static Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, | |||||
| string name = null) | |||||
| { | |||||
| if (dtype == image.dtype) | |||||
| return array_ops.identity(image, name: name); | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -168,6 +168,9 @@ namespace Tensorflow | |||||
| public static Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
| => gen_math_ops.mul(x, y, name: name); | => gen_math_ops.mul(x, y, name: name); | ||||
| public static Tensor not_equal<Tx, Ty>(Tx x, Ty y, string name = null) | |||||
| => gen_math_ops.not_equal(x, y, name: name); | |||||
| public static Tensor mul_no_nan<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor mul_no_nan<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
| => gen_math_ops.mul_no_nan(x, y, name: name); | => gen_math_ops.mul_no_nan(x, y, name: name); | ||||
| @@ -264,6 +267,9 @@ namespace Tensorflow | |||||
| return gen_math_ops.log(x, name); | return gen_math_ops.log(x, name); | ||||
| } | } | ||||
| public static Tensor logical_and(Tensor x, Tensor y, string name = null) | |||||
| => gen_math_ops.logical_and(x, y, name: name); | |||||
| public static Tensor lgamma(Tensor x, string name = null) | public static Tensor lgamma(Tensor x, string name = null) | ||||
| => gen_math_ops.lgamma(x, name: name); | => gen_math_ops.lgamma(x, name: name); | ||||
| @@ -98,7 +98,7 @@ namespace Tensorflow | |||||
| // float to be selected, hence we use a >= comparison. | // float to be selected, hence we use a >= comparison. | ||||
| var keep_mask = random_tensor >= rate; | var keep_mask = random_tensor >= rate; | ||||
| var ret = x * scale * math_ops.cast(keep_mask, x.dtype); | var ret = x * scale * math_ops.cast(keep_mask, x.dtype); | ||||
| ret.SetShape(x.TensorShape); | |||||
| ret.set_shape(x.TensorShape); | |||||
| return ret; | return ret; | ||||
| }); | }); | ||||
| } | } | ||||
| @@ -116,6 +116,19 @@ namespace Tensorflow | |||||
| return _softmax(logits, gen_nn_ops.log_softmax, axis, name); | return _softmax(logits, gen_nn_ops.log_softmax, axis, name); | ||||
| } | } | ||||
| public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) | |||||
| { | |||||
| return tf_with(ops.name_scope(name, "LeakyRelu", new { features, alpha }), scope => | |||||
| { | |||||
| name = scope; | |||||
| features = ops.convert_to_tensor(features, name: "features"); | |||||
| if (features.dtype.is_integer()) | |||||
| features = math_ops.cast(features, dtypes.float32); | |||||
| return gen_nn_ops.leaky_relu(features, alpha: alpha, name: name); | |||||
| //return math_ops.maximum(alpha * features, features, name: name); | |||||
| }); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Performs the max pooling on the input. | /// Performs the max pooling on the input. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -39,9 +39,10 @@ namespace Tensorflow | |||||
| { | { | ||||
| return tf_with(ops.name_scope(name, "random_normal", new { shape, mean, stddev }), scope => | return tf_with(ops.name_scope(name, "random_normal", new { shape, mean, stddev }), scope => | ||||
| { | { | ||||
| name = scope; | |||||
| var shape_tensor = _ShapeTensor(shape); | var shape_tensor = _ShapeTensor(shape); | ||||
| var mean_tensor = ops.convert_to_tensor(mean, dtype: dtype, name: "mean"); | var mean_tensor = ops.convert_to_tensor(mean, dtype: dtype, name: "mean"); | ||||
| var stddev_tensor = ops.convert_to_tensor(stddev, dtype: dtype, name = "stddev"); | |||||
| var stddev_tensor = ops.convert_to_tensor(stddev, dtype: dtype, name: "stddev"); | |||||
| var (seed1, seed2) = random_seed.get_seed(seed); | var (seed1, seed2) = random_seed.get_seed(seed); | ||||
| var rnd = gen_random_ops.random_standard_normal(shape_tensor, dtype: dtype, seed: seed1, seed2: seed2); | var rnd = gen_random_ops.random_standard_normal(shape_tensor, dtype: dtype, seed: seed1, seed2: seed2); | ||||
| var mul = rnd * stddev_tensor; | var mul = rnd * stddev_tensor; | ||||
| @@ -0,0 +1,38 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class string_ops | |||||
| { | |||||
| /// <summary> | |||||
| /// Return substrings from `Tensor` of strings. | |||||
| /// </summary> | |||||
| /// <param name="input"></param> | |||||
| /// <param name="pos"></param> | |||||
| /// <param name="len"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <param name="uint"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor substr(Tensor input, int pos, int len, | |||||
| string name = null, string @uint = "BYTE") | |||||
| => gen_string_ops.substr(input, pos, len, name: name, @uint: @uint); | |||||
| } | |||||
| } | |||||
| @@ -1,408 +1,413 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Collections; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Numerics; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class BaseSession : DisposableObject | |||||
| { | |||||
| protected Graph _graph; | |||||
| protected bool _opened; | |||||
| protected bool _closed; | |||||
| protected int _current_version; | |||||
| protected byte[] _target; | |||||
| public Graph graph => _graph; | |||||
| public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) | |||||
| { | |||||
| _graph = g is null ? ops.get_default_graph() : g; | |||||
| _graph.as_default(); | |||||
| _target = UTF8Encoding.UTF8.GetBytes(target); | |||||
| SessionOptions newOpts = null; | |||||
| if (opts == null) | |||||
| newOpts = new SessionOptions(); | |||||
| var status = new Status(); | |||||
| _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); | |||||
| // dispose newOpts | |||||
| if (opts == null) | |||||
| c_api.TF_DeleteSessionOptions(newOpts); | |||||
| status.Check(true); | |||||
| } | |||||
| public virtual void run(Operation op, params FeedItem[] feed_dict) | |||||
| { | |||||
| _run(op, feed_dict); | |||||
| } | |||||
| public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict) | |||||
| { | |||||
| return _run(fetche, feed_dict)[0]; | |||||
| } | |||||
| public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||||
| { | |||||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); | |||||
| return (results[0], results[1], results[2], results[3]); | |||||
| } | |||||
| public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||||
| { | |||||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); | |||||
| return (results[0], results[1], results[2]); | |||||
| } | |||||
| public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||||
| { | |||||
| var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); | |||||
| return (results[0], results[1]); | |||||
| } | |||||
| public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict) | |||||
| { | |||||
| return _run(fetches, feed_dict); | |||||
| } | |||||
| public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) | |||||
| { | |||||
| var feed_items = feed_dict == null ? new FeedItem[0] : | |||||
| feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||||
| return _run(fetches, feed_items); | |||||
| } | |||||
| private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) | |||||
| { | |||||
| var feed_dict_tensor = new Dictionary<object, object>(); | |||||
| var feed_map = new Dictionary<object, object>(); | |||||
| Func<FeedItem, IEnumerable<(object, object)>> feed_fn = (item) => | |||||
| { | |||||
| return new (object, object)[] { (item.Key, item.Value) }; | |||||
| }; | |||||
| // Validate and process feed_dict. | |||||
| if (feed_dict != null) | |||||
| { | |||||
| foreach (var feed in feed_dict) | |||||
| { | |||||
| foreach (var (subfeed, subfeed_val) in feed_fn(feed)) | |||||
| { | |||||
| var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); | |||||
| //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used | |||||
| feed_dict_tensor[subfeed_t] = subfeed_val; | |||||
| feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); | |||||
| } | |||||
| } | |||||
| } | |||||
| // Create a fetch handler to take care of the structure of fetches. | |||||
| var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | |||||
| // Run request and get response. | |||||
| // We need to keep the returned movers alive for the following _do_run(). | |||||
| // These movers are no longer needed when _do_run() completes, and | |||||
| // are deleted when `movers` goes out of scope when this _run() ends. | |||||
| var _ = _update_with_movers(); | |||||
| var final_fetches = fetch_handler.fetches(); | |||||
| var final_targets = fetch_handler.targets(); | |||||
| // We only want to really perform the run if fetches or targets are provided, | |||||
| // or if the call is a partial run that specifies feeds. | |||||
| var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); | |||||
| return fetch_handler.build_results(this, results); | |||||
| } | |||||
| /// <summary> | |||||
| /// Runs a step based on the given fetches and feeds. | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| /// <param name="target_list">A list of operations to be run, but not fetched.</param> | |||||
| /// <param name="fetch_list"></param> | |||||
| /// <param name="feed_dict"></param> | |||||
| /// <returns> | |||||
| /// A list of numpy ndarrays, corresponding to the elements of | |||||
| /// `fetch_list`. If the ith element of `fetch_list` contains the | |||||
| /// name of an operation, the first Tensor output of that operation | |||||
| /// will be returned for that element. | |||||
| /// </returns> | |||||
| private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | |||||
| { | |||||
| var feeds = feed_dict.Select(x => | |||||
| { | |||||
| if (x.Key is Tensor tensor) | |||||
| { | |||||
| switch (x.Value) | |||||
| { | |||||
| #if _REGEN | |||||
| %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] | |||||
| %foreach types% | |||||
| case #1 v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case #1[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| % | |||||
| #else | |||||
| case sbyte v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case sbyte[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case byte v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case byte[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case short v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case short[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ushort v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ushort[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case int v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case int[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case uint v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case uint[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case long v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case long[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ulong v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ulong[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case float v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case float[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case double v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case double[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case Complex v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case Complex[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| #endif | |||||
| case bool v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL)); | |||||
| case string v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case IntPtr v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case Tensor v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); | |||||
| case NDArray v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); | |||||
| default: | |||||
| throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "<null>")}"); | |||||
| } | |||||
| } | |||||
| throw new NotImplementedException("_do_run.feed_dict"); | |||||
| }).ToArray(); | |||||
| var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||||
| var targets = target_list; | |||||
| return _call_tf_sessionrun(feeds, fetches, target_list); | |||||
| } | |||||
| private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) | |||||
| { | |||||
| // Ensure any changes to the graph are reflected in the runtime. | |||||
| _extend_graph(); | |||||
| var status = new Status(); | |||||
| var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | |||||
| c_api.TF_SessionRun(_handle, | |||||
| run_options: null, | |||||
| inputs: feed_dict.Select(f => f.Key).ToArray(), | |||||
| input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), | |||||
| ninputs: feed_dict.Length, | |||||
| outputs: fetch_list, | |||||
| output_values: output_values, | |||||
| noutputs: fetch_list.Length, | |||||
| target_opers: target_list.Select(f => (IntPtr)f).ToArray(), | |||||
| ntargets: target_list.Count, | |||||
| run_metadata: IntPtr.Zero, | |||||
| status: status); | |||||
| status.Check(true); | |||||
| var result = new NDArray[fetch_list.Length]; | |||||
| for (int i = 0; i < fetch_list.Length; i++) | |||||
| result[i] = fetchValue(output_values[i]); | |||||
| for (int i = 0; i < feed_dict.Length; i++) | |||||
| feed_dict[i].Value.Dispose(); | |||||
| return result; | |||||
| } | |||||
| private unsafe NDArray fetchValue(IntPtr output) | |||||
| { | |||||
| var tensor = new Tensor(output); | |||||
| NDArray nd = null; | |||||
| Type type = tensor.dtype.as_numpy_datatype(); | |||||
| var ndims = tensor.shape; | |||||
| var offset = c_api.TF_TensorData(output); | |||||
| if(ndims.Length == 0) | |||||
| { | |||||
| switch (tensor.dtype) | |||||
| { | |||||
| case TF_DataType.TF_BOOL: | |||||
| nd = NDArray.Scalar(*(bool*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_STRING: | |||||
| var bytes = tensor.Data(); | |||||
| // wired, don't know why we have to start from offset 9. | |||||
| // length in the begin | |||||
| var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | |||||
| nd = NDArray.FromString(str); | |||||
| break; | |||||
| case TF_DataType.TF_UINT8: | |||||
| nd = NDArray.Scalar(*(byte*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_INT16: | |||||
| nd = NDArray.Scalar(*(short*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_INT32: | |||||
| nd = NDArray.Scalar(*(int*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_INT64: | |||||
| nd = NDArray.Scalar(*(long*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_FLOAT: | |||||
| nd = NDArray.Scalar(*(float*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| nd = NDArray.Scalar(*(double*)offset); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException("can't fetch output"); | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| switch (tensor.dtype) | |||||
| { | |||||
| case TF_DataType.TF_BOOL: | |||||
| var bools = new bool[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(bools).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_STRING: | |||||
| var bytes = tensor.Data(); | |||||
| // wired, don't know why we have to start from offset 9. | |||||
| // length in the begin | |||||
| var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | |||||
| nd = np.array(str); | |||||
| break; | |||||
| case TF_DataType.TF_UINT8: | |||||
| var _bytes = new byte[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| _bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(_bytes).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_INT16: | |||||
| var shorts = new short[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(shorts).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_INT32: | |||||
| var ints = new int[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| ints[i] = *(int*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(ints).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_INT64: | |||||
| var longs = new long[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| longs[i] = *(long*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(longs).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_FLOAT: | |||||
| var floats = new float[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| floats[i] = *(float*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(floats).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| var doubles = new double[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(doubles).reshape(ndims); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException("can't fetch output"); | |||||
| } | |||||
| } | |||||
| tensor.Dispose(); | |||||
| return nd; | |||||
| } | |||||
| /// <summary> | |||||
| /// If a tensor handle that is fed to a device incompatible placeholder, | |||||
| /// we move the tensor to the right device, generate a new tensor handle, | |||||
| /// and update feed_dict to use the new handle. | |||||
| /// </summary> | |||||
| private List<object> _update_with_movers() | |||||
| { | |||||
| return new List<object> { }; | |||||
| } | |||||
| private void _extend_graph() | |||||
| { | |||||
| } | |||||
| public void close() | |||||
| { | |||||
| Dispose(); | |||||
| } | |||||
| protected override void DisposeUnManagedState(IntPtr handle) | |||||
| { | |||||
| using (var status = new Status()) | |||||
| { | |||||
| c_api.TF_DeleteSession(handle, status); | |||||
| status.Check(true); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Collections; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Numerics; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class BaseSession : DisposableObject | |||||
| { | |||||
| protected Graph _graph; | |||||
| protected bool _opened; | |||||
| protected bool _closed; | |||||
| protected int _current_version; | |||||
| protected byte[] _target; | |||||
| public Graph graph => _graph; | |||||
| public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) | |||||
| { | |||||
| _graph = g is null ? ops.get_default_graph() : g; | |||||
| _graph.as_default(); | |||||
| _target = UTF8Encoding.UTF8.GetBytes(target); | |||||
| SessionOptions newOpts = null; | |||||
| if (opts == null) | |||||
| newOpts = new SessionOptions(); | |||||
| var status = new Status(); | |||||
| _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); | |||||
| // dispose newOpts | |||||
| if (opts == null) | |||||
| newOpts.Dispose(); | |||||
| status.Check(true); | |||||
| } | |||||
| public virtual void run(Operation op, params FeedItem[] feed_dict) | |||||
| { | |||||
| _run(op, feed_dict); | |||||
| } | |||||
| public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict) | |||||
| { | |||||
| return _run(fetche, feed_dict)[0]; | |||||
| } | |||||
| public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict) | |||||
| { | |||||
| return _run(fetche, feed_dict)[0]; | |||||
| } | |||||
| public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||||
| { | |||||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); | |||||
| return (results[0], results[1], results[2], results[3]); | |||||
| } | |||||
| public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||||
| { | |||||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); | |||||
| return (results[0], results[1], results[2]); | |||||
| } | |||||
| public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||||
| { | |||||
| var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); | |||||
| return (results[0], results[1]); | |||||
| } | |||||
| public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict) | |||||
| { | |||||
| return _run(fetches, feed_dict); | |||||
| } | |||||
| public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) | |||||
| { | |||||
| var feed_items = feed_dict == null ? new FeedItem[0] : | |||||
| feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||||
| return _run(fetches, feed_items); | |||||
| } | |||||
| private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) | |||||
| { | |||||
| var feed_dict_tensor = new Dictionary<object, object>(); | |||||
| var feed_map = new Dictionary<object, object>(); | |||||
| Func<FeedItem, IEnumerable<(object, object)>> feed_fn = (item) => | |||||
| { | |||||
| return new (object, object)[] { (item.Key, item.Value) }; | |||||
| }; | |||||
| // Validate and process feed_dict. | |||||
| if (feed_dict != null) | |||||
| { | |||||
| foreach (var feed in feed_dict) | |||||
| { | |||||
| foreach (var (subfeed, subfeed_val) in feed_fn(feed)) | |||||
| { | |||||
| var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); | |||||
| //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used | |||||
| feed_dict_tensor[subfeed_t] = subfeed_val; | |||||
| feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); | |||||
| } | |||||
| } | |||||
| } | |||||
| // Create a fetch handler to take care of the structure of fetches. | |||||
| var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | |||||
| // Run request and get response. | |||||
| // We need to keep the returned movers alive for the following _do_run(). | |||||
| // These movers are no longer needed when _do_run() completes, and | |||||
| // are deleted when `movers` goes out of scope when this _run() ends. | |||||
| var _ = _update_with_movers(); | |||||
| var final_fetches = fetch_handler.fetches(); | |||||
| var final_targets = fetch_handler.targets(); | |||||
| // We only want to really perform the run if fetches or targets are provided, | |||||
| // or if the call is a partial run that specifies feeds. | |||||
| var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); | |||||
| return fetch_handler.build_results(this, results); | |||||
| } | |||||
| /// <summary> | |||||
| /// Runs a step based on the given fetches and feeds. | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| /// <param name="target_list">A list of operations to be run, but not fetched.</param> | |||||
| /// <param name="fetch_list"></param> | |||||
| /// <param name="feed_dict"></param> | |||||
| /// <returns> | |||||
| /// A list of numpy ndarrays, corresponding to the elements of | |||||
| /// `fetch_list`. If the ith element of `fetch_list` contains the | |||||
| /// name of an operation, the first Tensor output of that operation | |||||
| /// will be returned for that element. | |||||
| /// </returns> | |||||
| private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | |||||
| { | |||||
| var feeds = feed_dict.Select(x => | |||||
| { | |||||
| if (x.Key is Tensor tensor) | |||||
| { | |||||
| switch (x.Value) | |||||
| { | |||||
| #if _REGEN | |||||
| %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] | |||||
| %foreach types% | |||||
| case #1 v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case #1[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| % | |||||
| #else | |||||
| case sbyte v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case sbyte[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case byte v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case byte[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case short v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case short[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ushort v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ushort[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case int v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case int[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case uint v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case uint[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case long v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case long[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ulong v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ulong[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case float v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case float[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case double v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case double[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case Complex v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case Complex[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| #endif | |||||
| case bool v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL)); | |||||
| case string v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case IntPtr v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case Tensor v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); | |||||
| case NDArray v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); | |||||
| default: | |||||
| throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "<null>")}"); | |||||
| } | |||||
| } | |||||
| throw new NotImplementedException("_do_run.feed_dict"); | |||||
| }).ToArray(); | |||||
| var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||||
| var targets = target_list; | |||||
| return _call_tf_sessionrun(feeds, fetches, target_list); | |||||
| } | |||||
| private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) | |||||
| { | |||||
| // Ensure any changes to the graph are reflected in the runtime. | |||||
| _extend_graph(); | |||||
| var status = new Status(); | |||||
| var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | |||||
| c_api.TF_SessionRun(_handle, | |||||
| run_options: null, | |||||
| inputs: feed_dict.Select(f => f.Key).ToArray(), | |||||
| input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), | |||||
| ninputs: feed_dict.Length, | |||||
| outputs: fetch_list, | |||||
| output_values: output_values, | |||||
| noutputs: fetch_list.Length, | |||||
| target_opers: target_list.Select(f => (IntPtr)f).ToArray(), | |||||
| ntargets: target_list.Count, | |||||
| run_metadata: IntPtr.Zero, | |||||
| status: status); | |||||
| status.Check(true); | |||||
| var result = new NDArray[fetch_list.Length]; | |||||
| for (int i = 0; i < fetch_list.Length; i++) | |||||
| result[i] = fetchValue(output_values[i]); | |||||
| for (int i = 0; i < feed_dict.Length; i++) | |||||
| feed_dict[i].Value.Dispose(); | |||||
| return result; | |||||
| } | |||||
| private unsafe NDArray fetchValue(IntPtr output) | |||||
| { | |||||
| var tensor = new Tensor(output); | |||||
| NDArray nd = null; | |||||
| Type type = tensor.dtype.as_numpy_dtype(); | |||||
| var ndims = tensor.shape; | |||||
| var offset = c_api.TF_TensorData(output); | |||||
| if(ndims.Length == 0) | |||||
| { | |||||
| switch (tensor.dtype) | |||||
| { | |||||
| case TF_DataType.TF_BOOL: | |||||
| nd = NDArray.Scalar(*(bool*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_STRING: | |||||
| var bytes = tensor.BufferToArray(); | |||||
| // wired, don't know why we have to start from offset 9. | |||||
| // length in the begin | |||||
| var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | |||||
| nd = NDArray.FromString(str); | |||||
| break; | |||||
| case TF_DataType.TF_UINT8: | |||||
| nd = NDArray.Scalar(*(byte*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_INT16: | |||||
| nd = NDArray.Scalar(*(short*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_INT32: | |||||
| nd = NDArray.Scalar(*(int*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_INT64: | |||||
| nd = NDArray.Scalar(*(long*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_FLOAT: | |||||
| nd = NDArray.Scalar(*(float*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| nd = NDArray.Scalar(*(double*)offset); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException("can't fetch output"); | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| switch (tensor.dtype) | |||||
| { | |||||
| case TF_DataType.TF_BOOL: | |||||
| var bools = new bool[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(bools).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_STRING: | |||||
| var bytes = tensor.BufferToArray(); | |||||
| // wired, don't know why we have to start from offset 9. | |||||
| // length in the begin | |||||
| var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | |||||
| nd = np.array(str); | |||||
| break; | |||||
| case TF_DataType.TF_UINT8: | |||||
| var _bytes = new byte[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| _bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(_bytes).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_INT16: | |||||
| var shorts = new short[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(shorts).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_INT32: | |||||
| var ints = new int[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| ints[i] = *(int*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(ints).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_INT64: | |||||
| var longs = new long[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| longs[i] = *(long*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(longs).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_FLOAT: | |||||
| var floats = new float[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| floats[i] = *(float*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(floats).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| var doubles = new double[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(doubles).reshape(ndims); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException("can't fetch output"); | |||||
| } | |||||
| } | |||||
| tensor.Dispose(); | |||||
| return nd; | |||||
| } | |||||
| /// <summary> | |||||
| /// If a tensor handle that is fed to a device incompatible placeholder, | |||||
| /// we move the tensor to the right device, generate a new tensor handle, | |||||
| /// and update feed_dict to use the new handle. | |||||
| /// </summary> | |||||
| private List<object> _update_with_movers() | |||||
| { | |||||
| return new List<object> { }; | |||||
| } | |||||
| private void _extend_graph() | |||||
| { | |||||
| } | |||||
| public void close() | |||||
| { | |||||
| Dispose(); | |||||
| } | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| { | |||||
| using (var status = new Status()) | |||||
| { | |||||
| c_api.TF_DeleteSession(handle, status); | |||||
| status.Check(true); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -16,5 +16,11 @@ | |||||
| public static implicit operator FeedItem((object, object) feed) | public static implicit operator FeedItem((object, object) feed) | ||||
| => new FeedItem(feed.Item1, feed.Item2); | => new FeedItem(feed.Item1, feed.Item2); | ||||
| public void Deconstruct(out object key, out object value) | |||||
| { | |||||
| key = Key; | |||||
| value = Value; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -32,13 +32,13 @@ namespace Tensorflow | |||||
| _handle = handle; | _handle = handle; | ||||
| } | } | ||||
| protected override void DisposeUnManagedState(IntPtr handle) | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| => c_api.TF_DeleteSessionOptions(handle); | => c_api.TF_DeleteSessionOptions(handle); | ||||
| public void SetConfig(ConfigProto config) | public void SetConfig(ConfigProto config) | ||||
| { | { | ||||
| var bytes = config.ToByteArray(); | |||||
| var proto = Marshal.AllocHGlobal(bytes.Length); | |||||
| var bytes = config.ToByteArray(); //TODO! we can use WriteTo | |||||
| var proto = Marshal.AllocHGlobal(bytes.Length); //TODO! potential memory leak | |||||
| Marshal.Copy(bytes, 0, proto, bytes.Length); | Marshal.Copy(bytes, 0, proto, bytes.Length); | ||||
| using (var status = new Status()) | using (var status = new Status()) | ||||
| @@ -17,6 +17,7 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using NumSharp.Backends; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -71,18 +72,18 @@ namespace Tensorflow | |||||
| { | { | ||||
| if(tensor_values.Length > 0) | if(tensor_values.Length > 0) | ||||
| { | { | ||||
| switch (tensor_values[0].dtype.Name) | |||||
| switch (tensor_values[0].typecode) | |||||
| { | { | ||||
| case "Int32": | |||||
| case NPTypeCode.Int32: | |||||
| full_values.Add(float.NaN); | full_values.Add(float.NaN); | ||||
| break; | break; | ||||
| case "Single": | |||||
| case NPTypeCode.Single: | |||||
| full_values.Add(float.NaN); | full_values.Add(float.NaN); | ||||
| break; | break; | ||||
| case "String": | |||||
| case NPTypeCode.String: | |||||
| full_values.Add(float.NaN); | full_values.Add(float.NaN); | ||||
| break; | break; | ||||
| case "Char": | |||||
| case NPTypeCode.Char: | |||||
| full_values.Add(float.NaN); | full_values.Add(float.NaN); | ||||
| break; | break; | ||||
| default: | default: | ||||
| @@ -100,21 +101,21 @@ namespace Tensorflow | |||||
| j += 1; | j += 1; | ||||
| if (value.ndim == 0) | if (value.ndim == 0) | ||||
| { | { | ||||
| switch (value.dtype.Name) | |||||
| switch (value.typecode) | |||||
| { | { | ||||
| case "Int16": | |||||
| case NPTypeCode.Int16: | |||||
| full_values.Add(value.GetValue<short>(0)); | full_values.Add(value.GetValue<short>(0)); | ||||
| break; | break; | ||||
| case "Int32": | |||||
| case NPTypeCode.Int32: | |||||
| full_values.Add(value.GetValue<int>(0)); | full_values.Add(value.GetValue<int>(0)); | ||||
| break; | break; | ||||
| case "Int64": | |||||
| case NPTypeCode.Int64: | |||||
| full_values.Add(value.GetValue<long>(0)); | full_values.Add(value.GetValue<long>(0)); | ||||
| break; | break; | ||||
| case "Single": | |||||
| case NPTypeCode.Single: | |||||
| full_values.Add(value.GetValue<float>(0)); | full_values.Add(value.GetValue<float>(0)); | ||||
| break; | break; | ||||
| case "Double": | |||||
| case NPTypeCode.Double: | |||||
| full_values.Add(value.GetValue<double>(0)); | full_values.Add(value.GetValue<double>(0)); | ||||
| break; | break; | ||||
| /*case "String": | /*case "String": | ||||
| @@ -27,13 +27,17 @@ namespace Tensorflow | |||||
| var handle = Marshal.AllocHGlobal(size * num_consumers); | var handle = Marshal.AllocHGlobal(size * num_consumers); | ||||
| int num = TF_OperationOutputConsumers(oper_out, handle, num_consumers); | int num = TF_OperationOutputConsumers(oper_out, handle, num_consumers); | ||||
| var consumers = new string[num_consumers]; | var consumers = new string[num_consumers]; | ||||
| for (int i = 0; i < num; i++) | |||||
| unsafe | |||||
| { | { | ||||
| TF_Input input = Marshal.PtrToStructure<TF_Input>(handle + i * size); | |||||
| consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(input.oper)); | |||||
| var inputptr = (TF_Input*) handle; | |||||
| for (int i = 0; i < num; i++) | |||||
| { | |||||
| var oper = (inputptr + i)->oper; | |||||
| consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(oper)); | |||||
| } | |||||
| } | } | ||||
| return consumers; | return consumers; | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -15,6 +15,8 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | using System; | ||||
| using System.Runtime.CompilerServices; | |||||
| using static Tensorflow.c_api; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -27,36 +29,36 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// Error message | /// Error message | ||||
| /// </summary> | /// </summary> | ||||
| public string Message => c_api.StringPiece(c_api.TF_Message(_handle)); | |||||
| public string Message => c_api.StringPiece(TF_Message(_handle)); | |||||
| /// <summary> | /// <summary> | ||||
| /// Error code | /// Error code | ||||
| /// </summary> | /// </summary> | ||||
| public TF_Code Code => c_api.TF_GetCode(_handle); | |||||
| public TF_Code Code => TF_GetCode(_handle); | |||||
| public Status() | public Status() | ||||
| { | { | ||||
| _handle = c_api.TF_NewStatus(); | |||||
| _handle = TF_NewStatus(); | |||||
| } | } | ||||
| public void SetStatus(TF_Code code, string msg) | public void SetStatus(TF_Code code, string msg) | ||||
| { | { | ||||
| c_api.TF_SetStatus(_handle, code, msg); | |||||
| TF_SetStatus(_handle, code, msg); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// Check status | /// Check status | ||||
| /// Throw exception with error message if code != TF_OK | /// Throw exception with error message if code != TF_OK | ||||
| /// </summary> | /// </summary> | ||||
| /// <exception cref="TensorflowException">When the returned check is not TF_Code.TF_OK</exception> | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| public void Check(bool throwException = false) | public void Check(bool throwException = false) | ||||
| { | { | ||||
| if(Code != TF_Code.TF_OK) | |||||
| if (Code != TF_Code.TF_OK) | |||||
| { | { | ||||
| Console.WriteLine(Message); | Console.WriteLine(Message); | ||||
| if (throwException) | if (throwException) | ||||
| { | |||||
| throw new Exception(Message); | |||||
| } | |||||
| throw new TensorflowException(Message); | |||||
| } | } | ||||
| } | } | ||||
| @@ -65,7 +67,7 @@ namespace Tensorflow | |||||
| return status._handle; | return status._handle; | ||||
| } | } | ||||
| protected override void DisposeUnManagedState(IntPtr handle) | |||||
| => c_api.TF_DeleteStatus(handle); | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| => TF_DeleteStatus(handle); | |||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -51,7 +51,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static unsafe extern IntPtr TF_NewStatus(); | |||||
| public static extern IntPtr TF_NewStatus(); | |||||
| /// <summary> | /// <summary> | ||||
| /// Record <code, msg> in *s. Any previous information is lost. | /// Record <code, msg> in *s. Any previous information is lost. | ||||
| @@ -33,11 +33,11 @@ namespace Tensorflow.Summaries | |||||
| { | { | ||||
| var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor }, default_name: "HistogramSummary"); | var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor }, default_name: "HistogramSummary"); | ||||
| var val = gen_logging_ops.histogram_summary(tag: tag, values: tensor, name: scope); | var val = gen_logging_ops.histogram_summary(tag: tag, values: tensor, name: scope); | ||||
| collect(val, collections?.ToList(), new List<string> { ops.GraphKeys.SUMMARIES }); | |||||
| collect(val, collections?.ToList(), new List<string> { tf.GraphKeys.SUMMARIES }); | |||||
| return val; | return val; | ||||
| } | } | ||||
| public Tensor merge_all(string key = ops.GraphKeys.SUMMARIES, string scope= null, string name= null) | |||||
| public Tensor merge_all(string key = "summaries", string scope= null, string name= null) | |||||
| { | { | ||||
| var summary_ops = ops.get_collection(key, scope: scope); | var summary_ops = ops.get_collection(key, scope: scope); | ||||
| if (summary_ops == null) | if (summary_ops == null) | ||||
| @@ -67,7 +67,7 @@ namespace Tensorflow.Summaries | |||||
| { | { | ||||
| var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor }); | var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor }); | ||||
| var val = gen_logging_ops.scalar_summary(tags: tag, values: tensor, name: scope); | var val = gen_logging_ops.scalar_summary(tags: tag, values: tensor, name: scope); | ||||
| collect(val, collections?.ToList(), new List<string> { ops.GraphKeys.SUMMARIES }); | |||||
| collect(val, collections?.ToList(), new List<string> { tf.GraphKeys.SUMMARIES }); | |||||
| return val; | return val; | ||||
| } | } | ||||
| @@ -5,8 +5,8 @@ | |||||
| <AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
| <RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
| <TargetTensorFlow>1.14.0</TargetTensorFlow> | <TargetTensorFlow>1.14.0</TargetTensorFlow> | ||||
| <Version>0.11.0</Version> | |||||
| <Authors>Haiping Chen, Meinrad Recheis</Authors> | |||||
| <Version>0.11.1</Version> | |||||
| <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | |||||
| <Company>SciSharp STACK</Company> | <Company>SciSharp STACK</Company> | ||||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | ||||
| <Copyright>Apache 2.0</Copyright> | <Copyright>Apache 2.0</Copyright> | ||||
| @@ -17,10 +17,16 @@ | |||||
| <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> | <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> | ||||
| <Description>Google's TensorFlow full binding in .NET Standard. | <Description>Google's TensorFlow full binding in .NET Standard. | ||||
| Docs: https://tensorflownet.readthedocs.io</Description> | Docs: https://tensorflownet.readthedocs.io</Description> | ||||
| <AssemblyVersion>0.11.10.0</AssemblyVersion> | |||||
| <PackageReleaseNotes>Changes since v0.10.0:</PackageReleaseNotes> | |||||
| <AssemblyVersion>0.11.1.0</AssemblyVersion> | |||||
| <PackageReleaseNotes>Changes since v0.10.0: | |||||
| 1. Upgrade NumSharp to v0.20. | |||||
| 2. Add DisposableObject class to manage object lifetime. | |||||
| 3. Add tf.no_op, tf.nn.in_top_k, tf.GraphKeys and tf.trainable_variables. | |||||
| 4. Change tensorflow to non-static class in order to execute some initialization process. | |||||
| 5. Overload session.run(), make syntax simpler. | |||||
| 6. Add Local Response Normalization.</PackageReleaseNotes> | |||||
| <LangVersion>7.3</LangVersion> | <LangVersion>7.3</LangVersion> | ||||
| <FileVersion>0.11.10.0</FileVersion> | |||||
| <FileVersion>0.11.1.0</FileVersion> | |||||
| <PackageLicenseFile>LICENSE</PackageLicenseFile> | <PackageLicenseFile>LICENSE</PackageLicenseFile> | ||||
| <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | ||||
| <SignAssembly>true</SignAssembly> | <SignAssembly>true</SignAssembly> | ||||
| @@ -52,7 +58,7 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="Google.Protobuf" Version="3.5.1" /> | <PackageReference Include="Google.Protobuf" Version="3.5.1" /> | ||||
| <PackageReference Include="NumSharp" Version="0.20.0-alpha1" /> | |||||
| <PackageReference Include="NumSharp" Version="0.20.0" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| @@ -16,11 +16,13 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using System.Diagnostics.CodeAnalysis; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Numerics; | using System.Numerics; | ||||
| using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using System.Text; | using System.Text; | ||||
| using NumSharp.Backends; | |||||
| using NumSharp.Backends.Unmanaged; | using NumSharp.Backends.Unmanaged; | ||||
| using static Tensorflow.c_api; | using static Tensorflow.c_api; | ||||
| @@ -50,9 +52,9 @@ namespace Tensorflow | |||||
| private DeallocatorArgs _deallocatorArgs = new DeallocatorArgs() { gc_handle = IntPtr.Zero }; | private DeallocatorArgs _deallocatorArgs = new DeallocatorArgs() { gc_handle = IntPtr.Zero }; | ||||
| // note: they must be assigned to a static variable in order to work as unmanaged callbacks | // note: they must be assigned to a static variable in order to work as unmanaged callbacks | ||||
| static Deallocator _hGlobalDeallocator = FreeHGlobalMemory; | |||||
| static Deallocator _gcHandleDeallocator = FreeGCHandle; | |||||
| private static Deallocator _nothingDeallocator = FreeNothing; | |||||
| private static readonly Deallocator _hGlobalDeallocator = FreeHGlobalMemory; | |||||
| private static readonly Deallocator _gcHandleDeallocator = FreeGCHandle; | |||||
| private static readonly Deallocator _nothingDeallocator = FreeNothing; | |||||
| /// <summary> | /// <summary> | ||||
| /// Create a Tensor object from an existing TF handle | /// Create a Tensor object from an existing TF handle | ||||
| @@ -462,7 +464,7 @@ namespace Tensorflow | |||||
| *v = value; | *v = value; | ||||
| _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(Complex)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(Complex), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); | _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(Complex)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(Complex), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); | ||||
| IsMemoryOwner=true; | IsMemoryOwner=true; | ||||
| } | |||||
| } | |||||
| #endif | #endif | ||||
| /// <summary> | /// <summary> | ||||
| @@ -477,7 +479,7 @@ namespace Tensorflow | |||||
| IntPtr tensor = c_api.TF_TensorData(handle); | IntPtr tensor = c_api.TF_TensorData(handle); | ||||
| Marshal.WriteInt64(tensor, 0); | Marshal.WriteInt64(tensor, 0); | ||||
| fixed (byte* src = &buffer[0]) | |||||
| fixed (byte* src = buffer) | |||||
| c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); | c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); | ||||
| _handle = handle; | _handle = handle; | ||||
| status.Check(true); | status.Check(true); | ||||
| @@ -486,35 +488,54 @@ namespace Tensorflow | |||||
| public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) | public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) | ||||
| { | { | ||||
| // todo: handle nd of type "String" here too | // todo: handle nd of type "String" here too | ||||
| if (tensorDType == TF_DataType.TF_STRING && nd.dtype.Name == "Byte") | |||||
| if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte) | |||||
| { | { | ||||
| var buffer = nd.ToArray<byte>(); | |||||
| var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); | |||||
| var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); | |||||
| IntPtr tensor = c_api.TF_TensorData(handle); | |||||
| Marshal.WriteInt64(tensor, 0); | |||||
| if (nd.Unsafe.Storage.Shape.IsContiguous) | |||||
| { | |||||
| var bytesLength = (UIntPtr)nd.size; | |||||
| var size = c_api.TF_StringEncodedSize(bytesLength); | |||||
| var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); | |||||
| IntPtr tensor = c_api.TF_TensorData(handle); | |||||
| Marshal.WriteInt64(tensor, 0); | |||||
| var status = new Status(); | |||||
| c_api.TF_StringEncode((byte*) nd.Unsafe.Address, bytesLength, (sbyte*) (tensor + sizeof(Int64)), size, status); | |||||
| status.Check(true); | |||||
| _handle = handle; | |||||
| IsMemoryOwner = false; | |||||
| } | |||||
| else | |||||
| { | |||||
| var buffer = nd.ToArray<byte>(); | |||||
| var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length); | |||||
| var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); | |||||
| IntPtr tensor = c_api.TF_TensorData(handle); | |||||
| Marshal.WriteInt64(tensor, 0); | |||||
| var status = new Status(); | |||||
| fixed (byte* src = buffer) | |||||
| c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, status); | |||||
| status.Check(true); | |||||
| _handle = handle; | |||||
| IsMemoryOwner = false; | |||||
| } | |||||
| var status = new Status(); | |||||
| fixed (byte* src = &buffer[0]) | |||||
| c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); | |||||
| status.Check(true); | |||||
| _handle=handle; | |||||
| IsMemoryOwner = false; | |||||
| return; | return; | ||||
| } | } | ||||
| _handle = CreateTensorFromNDArray(nd, tensorDType); | _handle = CreateTensorFromNDArray(nd, tensorDType); | ||||
| IsMemoryOwner = true; | |||||
| } | } | ||||
| private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype) | private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype) | ||||
| { | { | ||||
| if (nd.dtype.Name == "String") | |||||
| throw new NotImplementedException("Support for NDArray of type string not implemented yet"); | |||||
| if (nd.dtype.Name == "String") | |||||
| throw new NotImplementedException("Support for NDArray of type string not implemented yet"); | |||||
| IArraySlice arraySlice; | IArraySlice arraySlice; | ||||
| var shape = nd.Unsafe.Storage.Shape; | |||||
| if (shape.IsSliced || shape.IsBroadcasted) | |||||
| if (nd.Unsafe.Storage.Shape.IsContiguous == false) | |||||
| { | { | ||||
| // the memory is NOT contiguous, so we have to copy the view into a contiguous memory block. | // the memory is NOT contiguous, so we have to copy the view into a contiguous memory block. | ||||
| arraySlice = nd.CloneData(); | arraySlice = nd.CloneData(); | ||||
| @@ -527,51 +548,52 @@ namespace Tensorflow | |||||
| this.Tag = arraySlice; // keep a reference to the memory block to make sure it is not disposed while TF is using it | this.Tag = arraySlice; // keep a reference to the memory block to make sure it is not disposed while TF is using it | ||||
| var ptr = new IntPtr(arraySlice.Address); | var ptr = new IntPtr(arraySlice.Address); | ||||
| int num_bytes = (nd.size * nd.dtypesize); | int num_bytes = (nd.size * nd.dtypesize); | ||||
| var dtype = given_dtype ?? ToTFDataType(nd.dtype); | |||||
| var dtype = given_dtype ?? nd.dtype.as_dtype(); | |||||
| var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs); | var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs); | ||||
| IsMemoryOwner = false; | IsMemoryOwner = false; | ||||
| return handle; | return handle; | ||||
| } | |||||
| public unsafe Tensor(byte[][] buffer, long[] shape) | |||||
| { | |||||
| int size = 0; | |||||
| foreach (var b in buffer) | |||||
| { | |||||
| size += (int)TF_StringEncodedSize((UIntPtr)b.Length); | |||||
| } | |||||
| int totalSize = size + buffer.Length * 8; | |||||
| ulong offset = 0; | |||||
| IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize); | |||||
| // Clear offset table | |||||
| IntPtr pOffset = TF_TensorData(handle); | |||||
| IntPtr dst = pOffset + buffer.Length * 8; | |||||
| IntPtr dstLimit = pOffset + totalSize; | |||||
| for (int i = 0; i < buffer.Length; i++) | |||||
| { | |||||
| Marshal.WriteInt64(pOffset, (long)offset); | |||||
| using (var status = new Status()) | |||||
| { | |||||
| fixed (byte* src = &buffer[i][0]) | |||||
| { | |||||
| var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status); | |||||
| status.Check(true); | |||||
| pOffset += 8; | |||||
| dst += (int)written; | |||||
| offset += written; | |||||
| } | |||||
| } | |||||
| } | |||||
| _handle = handle; | |||||
| } | |||||
| public unsafe Tensor(byte[][] buffer, long[] shape) | |||||
| { | |||||
| int size = 0; | |||||
| foreach (var b in buffer) | |||||
| { | |||||
| size += (int)TF_StringEncodedSize((UIntPtr)b.Length); | |||||
| } | |||||
| int totalSize = size + buffer.Length * 8; | |||||
| ulong offset = 0; | |||||
| IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize); | |||||
| // Clear offset table | |||||
| IntPtr pOffset = TF_TensorData(handle); | |||||
| IntPtr dst = pOffset + buffer.Length * 8; | |||||
| IntPtr dstLimit = pOffset + totalSize; | |||||
| for (int i = 0; i < buffer.Length; i++) | |||||
| { | |||||
| Marshal.WriteInt64(pOffset, (long)offset); | |||||
| using (var status = new Status()) | |||||
| { | |||||
| fixed (byte* src = &buffer[i][0]) | |||||
| { | |||||
| var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status); | |||||
| status.Check(true); | |||||
| pOffset += 8; | |||||
| dst += (int)written; | |||||
| offset += written; | |||||
| } | |||||
| } | |||||
| } | |||||
| _handle = handle; | |||||
| } | } | ||||
| public Tensor(Operation op, int value_index, TF_DataType dtype) | public Tensor(Operation op, int value_index, TF_DataType dtype) | ||||
| { | { | ||||
| _op = op; | _op = op; | ||||
| _value_index = value_index; | _value_index = value_index; | ||||
| _dtype = dtype; | |||||
| _override_dtype = dtype; | |||||
| _id = ops.uid(); | _id = ops.uid(); | ||||
| } | } | ||||
| @@ -589,11 +611,11 @@ namespace Tensorflow | |||||
| /// specified dimensions. | /// specified dimensions. | ||||
| /// </remarks> | /// </remarks> | ||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | [MethodImpl(MethodImplOptions.AggressiveInlining)] | ||||
| [SuppressMessage("ReSharper", "LocalVariableHidesMember")] | |||||
| protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int element_size) | protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int element_size) | ||||
| { | { | ||||
| if (dt == TF_DataType.TF_STRING && data is byte[]) | |||||
| if (dt == TF_DataType.TF_STRING && data is byte[] buffer) | |||||
| { | { | ||||
| var buffer = (byte[])data; | |||||
| var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); | var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); | ||||
| var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); | var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); | ||||
| @@ -601,7 +623,7 @@ namespace Tensorflow | |||||
| Marshal.WriteInt64(tensor, 0); | Marshal.WriteInt64(tensor, 0); | ||||
| var status = new Status(); | var status = new Status(); | ||||
| fixed (byte* src = &buffer[0]) | |||||
| fixed (byte* src = buffer) | |||||
| c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); | c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); | ||||
| status.Check(true); | status.Check(true); | ||||
| @@ -644,8 +666,9 @@ namespace Tensorflow | |||||
| { | { | ||||
| if (args.deallocator_called) | if (args.deallocator_called) | ||||
| return; | return; | ||||
| // NumSharp will dispose | // NumSharp will dispose | ||||
| // Marshal.FreeHGlobal(dataPtr); | |||||
| Marshal.FreeHGlobal(dataPtr); | |||||
| args.deallocator_called = true; | args.deallocator_called = true; | ||||
| } | } | ||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | using System; | ||||
| using System.Runtime.CompilerServices; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -6,86 +7,142 @@ namespace Tensorflow | |||||
| { | { | ||||
| public static explicit operator bool(Tensor tensor) | public static explicit operator bool(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<bool>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_BOOL); | |||||
| return *(bool*) tensor.buffer; | |||||
| } | |||||
| } | } | ||||
| public static explicit operator sbyte(Tensor tensor) | public static explicit operator sbyte(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<sbyte>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_INT8); | |||||
| return *(sbyte*) tensor.buffer; | |||||
| } | |||||
| } | } | ||||
| public static explicit operator byte(Tensor tensor) | public static explicit operator byte(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<byte>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_UINT8); | |||||
| return *(byte*) tensor.buffer; | |||||
| } | |||||
| } | } | ||||
| public static explicit operator ushort(Tensor tensor) | public static explicit operator ushort(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<ushort>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_UINT16); | |||||
| return *(ushort*) tensor.buffer; | |||||
| } | |||||
| } | } | ||||
| public static explicit operator short(Tensor tensor) | public static explicit operator short(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<short>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_INT16); | |||||
| return *(short*) tensor.buffer; | |||||
| } | |||||
| } | } | ||||
| public static explicit operator int(Tensor tensor) | public static explicit operator int(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<int>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_INT32); | |||||
| return *(int*) tensor.buffer; | |||||
| } | |||||
| } | } | ||||
| public static explicit operator uint(Tensor tensor) | public static explicit operator uint(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<uint>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_UINT32); | |||||
| return *(uint*) tensor.buffer; | |||||
| } | |||||
| } | } | ||||
| public static explicit operator long(Tensor tensor) | public static explicit operator long(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<long>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_INT64); | |||||
| return *(long*) tensor.buffer; | |||||
| } | |||||
| } | } | ||||
| public static explicit operator ulong(Tensor tensor) | public static explicit operator ulong(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<ulong>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_UINT64); | |||||
| return *(ulong*) tensor.buffer; | |||||
| } | |||||
| } | } | ||||
| public static explicit operator float(Tensor tensor) | public static explicit operator float(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<float>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_FLOAT); | |||||
| return *(float*) tensor.buffer; | |||||
| } | |||||
| } | } | ||||
| public static explicit operator double(Tensor tensor) | public static explicit operator double(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<double>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_DOUBLE); | |||||
| return *(double*) tensor.buffer; | |||||
| } | |||||
| } | |||||
| public static explicit operator string(Tensor tensor) | |||||
| { | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_STRING); | |||||
| return new string((char*) tensor.buffer, 0, (int) tensor.size); | |||||
| } | |||||
| } | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| private static void EnsureDType(Tensor tensor, TF_DataType @is) | |||||
| { | |||||
| if (tensor.dtype != @is) | |||||
| throw new InvalidCastException($"Unable to cast scalar tensor {tensor.dtype} to {@is}"); | |||||
| } | } | ||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| private static void EnsureScalar(Tensor tensor) | private static void EnsureScalar(Tensor tensor) | ||||
| { | { | ||||
| if (tensor == null) | if (tensor == null) | ||||
| { | |||||
| throw new ArgumentNullException(nameof(tensor)); | throw new ArgumentNullException(nameof(tensor)); | ||||
| } | |||||
| if (tensor.TensorShape.ndim != 0) | if (tensor.TensorShape.ndim != 0) | ||||
| { | |||||
| throw new ArgumentException("Tensor must have 0 dimensions in order to convert to scalar"); | throw new ArgumentException("Tensor must have 0 dimensions in order to convert to scalar"); | ||||
| } | |||||
| if (tensor.TensorShape.size != 1) | if (tensor.TensorShape.size != 1) | ||||
| { | |||||
| throw new ArgumentException("Tensor must have size 1 in order to convert to scalar"); | throw new ArgumentException("Tensor must have size 1 in order to convert to scalar"); | ||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -69,11 +69,12 @@ namespace Tensorflow | |||||
| TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32, | TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32, | ||||
| TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 | TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 | ||||
| }; | }; | ||||
| public static Tensor operator /(double x, Tensor y) => BinaryOpWrapper("truediv", x, y); | public static Tensor operator /(double x, Tensor y) => BinaryOpWrapper("truediv", x, y); | ||||
| public static Tensor operator /(float x, Tensor y) => BinaryOpWrapper("truediv", x, y); | public static Tensor operator /(float x, Tensor y) => BinaryOpWrapper("truediv", x, y); | ||||
| public static Tensor operator /(int x, Tensor y) => BinaryOpWrapper("floordiv", x, y); | public static Tensor operator /(int x, Tensor y) => BinaryOpWrapper("floordiv", x, y); | ||||
| public static Tensor operator /(Tensor x, Tensor y) => | public static Tensor operator /(Tensor x, Tensor y) => | ||||
| _intTfDataTypes.Contains(x._dtype) | |||||
| _intTfDataTypes.Contains(x.dtype) | |||||
| ? BinaryOpWrapper("floordiv", x, y) | ? BinaryOpWrapper("floordiv", x, y) | ||||
| : BinaryOpWrapper("truediv", x, y); | : BinaryOpWrapper("truediv", x, y); | ||||
| public static Tensor operator /(Tensor x, int y) => BinaryOpWrapper("floordiv", x, y); | public static Tensor operator /(Tensor x, int y) => BinaryOpWrapper("floordiv", x, y); | ||||
| @@ -122,8 +123,7 @@ namespace Tensorflow | |||||
| if (y is Tensor tr) | if (y is Tensor tr) | ||||
| dtype = tr.dtype.as_base_dtype(); | dtype = tr.dtype.as_base_dtype(); | ||||
| var namescope = ops.name_scope(null, name, new { x, y }); | |||||
| return tf_with(namescope, scope => | |||||
| return tf_with(ops.name_scope(null, name, new { x, y }), scope => | |||||
| { | { | ||||
| Tensor result = null; | Tensor result = null; | ||||
| var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); | var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); | ||||
| @@ -155,7 +155,6 @@ namespace Tensorflow | |||||
| return result; | return result; | ||||
| }); | }); | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -17,9 +17,16 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics.CodeAnalysis; | |||||
| using System.Globalization; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.CompilerServices; | |||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using System.Text; | using System.Text; | ||||
| using System.Threading.Tasks; | |||||
| using NumSharp.Backends; | |||||
| using NumSharp.Backends.Unmanaged; | |||||
| using NumSharp.Utilities; | |||||
| using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -29,42 +36,68 @@ namespace Tensorflow | |||||
| /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. | /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. | ||||
| /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. | /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. | ||||
| /// </summary> | /// </summary> | ||||
| [SuppressMessage("ReSharper", "ConvertToAutoProperty")] | |||||
| public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike | public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike | ||||
| { | { | ||||
| private int _id; | |||||
| private Operation _op; | |||||
| private readonly int _id; | |||||
| private readonly Operation _op; | |||||
| private readonly int _value_index; | |||||
| private TF_Output? _tf_output; | |||||
| private readonly TF_DataType _override_dtype; | |||||
| public int Id => _id; | public int Id => _id; | ||||
| /// <summary> | |||||
| /// The Graph that contains this tensor. | |||||
| /// </summary> | |||||
| public Graph graph => op?.graph; | public Graph graph => op?.graph; | ||||
| /// <summary> | |||||
| /// The Operation that produces this tensor as an output. | |||||
| /// </summary> | |||||
| public Operation op => _op; | public Operation op => _op; | ||||
| public Tensor[] outputs => op.outputs; | public Tensor[] outputs => op.outputs; | ||||
| /// <summary> | /// <summary> | ||||
| /// The string name of this tensor. | |||||
| /// The string name of this tensor. | |||||
| /// </summary> | /// </summary> | ||||
| public string name => $"{(op == null ? "<unnamed Operation>" : $"{op.name}:{_value_index}")}"; | public string name => $"{(op == null ? "<unnamed Operation>" : $"{op.name}:{_value_index}")}"; | ||||
| private int _value_index; | |||||
| /// <summary> | |||||
| /// The index of this tensor in the outputs of its Operation. | |||||
| /// </summary> | |||||
| public int value_index => _value_index; | public int value_index => _value_index; | ||||
| private TF_DataType _dtype = TF_DataType.DtInvalid; | |||||
| public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle); | |||||
| /// <summary> | |||||
| /// The DType of elements in this tensor. | |||||
| /// </summary> | |||||
| public TF_DataType dtype => _handle == IntPtr.Zero ? _override_dtype : c_api.TF_TensorType(_handle); | |||||
| public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); | public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); | ||||
| public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); | public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); | ||||
| public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; | public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; | ||||
| public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); | public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); | ||||
| public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | ||||
| public int NDims => rank; | |||||
| private TF_Output? _tf_output; | |||||
| /// <summary> | |||||
| /// The name of the device on which this tensor will be produced, or null. | |||||
| /// </summary> | |||||
| public string Device => op.Device; | |||||
| public int[] dims => shape; | |||||
| /// <summary> | /// <summary> | ||||
| /// used for keep other pointer when do implicit operating | |||||
| /// Used for keep other pointer when do implicit operating | |||||
| /// </summary> | /// </summary> | ||||
| public object Tag { get; set; } | public object Tag { get; set; } | ||||
| /// <summary> | |||||
| /// Returns the shape of a tensor. | |||||
| /// </summary> | |||||
| /// <remarks>https://www.tensorflow.org/api_docs/python/tf/shape</remarks> | |||||
| public int[] shape | public int[] shape | ||||
| { | { | ||||
| get | get | ||||
| @@ -76,14 +109,13 @@ namespace Tensorflow | |||||
| var status = new Status(); | var status = new Status(); | ||||
| c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); | c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); | ||||
| status.Check(); | status.Check(); | ||||
| } | |||||
| else | |||||
| } else | |||||
| { | { | ||||
| for (int i = 0; i < rank; i++) | for (int i = 0; i < rank; i++) | ||||
| dims[i] = c_api.TF_Dim(_handle, i); | dims[i] = c_api.TF_Dim(_handle, i); | ||||
| } | } | ||||
| return dims.Select(x => Convert.ToInt32(x)).ToArray(); | |||||
| return dims.Select(x => ((IConvertible) x).ToInt32(CultureInfo.InvariantCulture)).ToArray(); | |||||
| } | } | ||||
| set | set | ||||
| @@ -93,38 +125,52 @@ namespace Tensorflow | |||||
| if (value == null) | if (value == null) | ||||
| c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); | c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); | ||||
| else | else | ||||
| c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(x => Convert.ToInt64(x)).ToArray(), value.Length, status); | |||||
| c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); | |||||
| } | } | ||||
| } | } | ||||
| public int[] _shape_tuple() | public int[] _shape_tuple() | ||||
| { | { | ||||
| if (shape == null) return null; | |||||
| return shape.Select(x => (int)x).ToArray(); | |||||
| return (int[]) shape.Clone(); | |||||
| } | } | ||||
| public TensorShape TensorShape => tensor_util.to_shape(shape); | public TensorShape TensorShape => tensor_util.to_shape(shape); | ||||
| public void SetShape(TensorShape shape) | |||||
| /// <summary> | |||||
| /// Updates the shape of this tensor. | |||||
| /// </summary> | |||||
| public void set_shape(TensorShape shape) | |||||
| { | { | ||||
| this.shape = shape.dims; | |||||
| this.shape = (int[]) shape.dims.Clone(); | |||||
| } | } | ||||
| /// <summary> | |||||
| /// Updates the shape of this tensor. | |||||
| /// </summary> | |||||
| [Obsolete("Please use set_shape(TensorShape shape) instead.", false)] | |||||
| public void SetShape(TensorShape shape) | |||||
| { | |||||
| this.shape = (int[]) shape.dims.Clone(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Updates the shape of this tensor. | |||||
| /// </summary> | |||||
| public void set_shape(Tensor shape) | public void set_shape(Tensor shape) | ||||
| { | { | ||||
| // ReSharper disable once MergeConditionalExpression | |||||
| this.shape = shape is null ? null : shape.shape; | this.shape = shape is null ? null : shape.shape; | ||||
| } | } | ||||
| public int[] dims => shape; | |||||
| /// <summary> | /// <summary> | ||||
| /// number of dimensions | |||||
| /// 0 Scalar (magnitude only) | |||||
| /// 1 Vector (magnitude and direction) | |||||
| /// 2 Matrix (table of numbers) | |||||
| /// 3 3-Tensor (cube of numbers) | |||||
| /// number of dimensions <br></br> | |||||
| /// 0 Scalar (magnitude only) <br></br> | |||||
| /// 1 Vector (magnitude and direction) <br></br> | |||||
| /// 2 Matrix (table of numbers) <br></br> | |||||
| /// 3 3-Tensor (cube of numbers) <br></br> | |||||
| /// n n-Tensor (you get the idea) | /// n n-Tensor (you get the idea) | ||||
| /// </summary> | /// </summary> | ||||
| /// <remarks>https://www.tensorflow.org/api_docs/python/tf/rank</remarks> | |||||
| public int rank | public int rank | ||||
| { | { | ||||
| get | get | ||||
| @@ -137,17 +183,15 @@ namespace Tensorflow | |||||
| status.Check(); | status.Check(); | ||||
| return ndim; | return ndim; | ||||
| } | } | ||||
| else | |||||
| { | |||||
| return c_api.TF_NumDims(_handle); | |||||
| } | |||||
| return c_api.TF_NumDims(_handle); | |||||
| } | } | ||||
| } | } | ||||
| public int NDims => rank; | |||||
| public string Device => op.Device; | |||||
| /// <summary> | |||||
| /// Returns a list of Operations that consume this tensor. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public Operation[] consumers() | public Operation[] consumers() | ||||
| { | { | ||||
| var output = _as_tf_output(); | var output = _as_tf_output(); | ||||
| @@ -157,38 +201,181 @@ namespace Tensorflow | |||||
| public TF_Output _as_tf_output() | public TF_Output _as_tf_output() | ||||
| { | { | ||||
| if(!_tf_output.HasValue) | |||||
| if (!_tf_output.HasValue) | |||||
| _tf_output = new TF_Output(op, value_index); | _tf_output = new TF_Output(op, value_index); | ||||
| return _tf_output.Value; | return _tf_output.Value; | ||||
| } | } | ||||
| public T[] Data<T>() | |||||
| [Obsolete("Please use ToArray<T>() instead.", false)] | |||||
| public T[] Data<T>() where T : unmanaged | |||||
| { | { | ||||
| // Column major order | |||||
| // https://en.wikipedia.org/wiki/File:Row_and_column_major_order.svg | |||||
| // matrix:[[1, 2, 3], [4, 5, 6]] | |||||
| // index: 0 2 4 1 3 5 | |||||
| // result: 1 4 2 5 3 6 | |||||
| var data = new T[size]; | |||||
| for (ulong i = 0; i < size; i++) | |||||
| return ToArray<T>(); | |||||
| } | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| /// <returns></returns> | |||||
| /// <exception cref="ArgumentException">When <typeparam name="T"> is string </typeparam></exception> | |||||
| public T[] ToArray<T>() where T : unmanaged | |||||
| { | |||||
| //Are the types matching? | |||||
| if (typeof(T).as_dtype() == dtype) | |||||
| { | { | ||||
| data[i] = Marshal.PtrToStructure<T>(buffer + (int)(i * itemsize)); | |||||
| } | |||||
| if (NDims == 0 && size == 1) //is it a scalar? | |||||
| { | |||||
| unsafe | |||||
| { | |||||
| return new T[] {*(T*) buffer}; | |||||
| } | |||||
| } | |||||
| //types match, no need to perform cast | |||||
| var ret = new T[size]; | |||||
| unsafe | |||||
| { | |||||
| var len = (long) size; | |||||
| fixed (T* dst = ret) | |||||
| { | |||||
| //T can only be unmanaged, I believe it is safe to say that MemoryCopy is valid for all cases this method can be called. | |||||
| var src = (T*) buffer; | |||||
| len *= ((long) itemsize); | |||||
| System.Buffer.MemoryCopy(src, dst, len, len); | |||||
| } | |||||
| } | |||||
| return ret; | |||||
| } else | |||||
| { | |||||
| //types do not match, need to perform cast | |||||
| if (NDims == 0 && size == 1) //is it a scalar? | |||||
| { | |||||
| unsafe | |||||
| { | |||||
| #if _REGEN | |||||
| #region Compute | |||||
| switch (dtype.as_numpy_dtype().GetTypeCode()) | |||||
| { | |||||
| %foreach supported_dtypes,supported_dtypes_lowercase% | |||||
| case NPTypeCode.#1: return new T[] {Converts.ChangeType<T>(*(#2*) buffer, NPTypeCode.#1)}; | |||||
| % | |||||
| case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this, NPTypeCode.String)}; | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #else | |||||
| #region Compute | |||||
| switch (dtype.as_numpy_dtype()?.GetTypeCode()) | |||||
| { | |||||
| case NPTypeCode.Boolean: return new T[] {Converts.ChangeType<T>(*(bool*) buffer, NPTypeCode.Boolean)}; | |||||
| case NPTypeCode.Byte: return new T[] {Converts.ChangeType<T>(*(byte*) buffer, NPTypeCode.Byte)}; | |||||
| case NPTypeCode.Int16: return new T[] {Converts.ChangeType<T>(*(short*) buffer, NPTypeCode.Int16)}; | |||||
| case NPTypeCode.UInt16: return new T[] {Converts.ChangeType<T>(*(ushort*) buffer, NPTypeCode.UInt16)}; | |||||
| case NPTypeCode.Int32: return new T[] {Converts.ChangeType<T>(*(int*) buffer, NPTypeCode.Int32)}; | |||||
| case NPTypeCode.UInt32: return new T[] {Converts.ChangeType<T>(*(uint*) buffer, NPTypeCode.UInt32)}; | |||||
| case NPTypeCode.Int64: return new T[] {Converts.ChangeType<T>(*(long*) buffer, NPTypeCode.Int64)}; | |||||
| case NPTypeCode.UInt64: return new T[] {Converts.ChangeType<T>(*(ulong*) buffer, NPTypeCode.UInt64)}; | |||||
| case NPTypeCode.Char: return new T[] {Converts.ChangeType<T>(*(char*) buffer, NPTypeCode.Char)}; | |||||
| case NPTypeCode.Double: return new T[] {Converts.ChangeType<T>(*(double*) buffer, NPTypeCode.Double)}; | |||||
| case NPTypeCode.Single: return new T[] {Converts.ChangeType<T>(*(float*) buffer, NPTypeCode.Single)}; | |||||
| case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this, NPTypeCode.String)}; | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #endif | |||||
| } | |||||
| } | |||||
| return data; | |||||
| var ret = new T[size]; | |||||
| unsafe | |||||
| { | |||||
| var len = (long) size; | |||||
| fixed (T* dstRet = ret) | |||||
| { | |||||
| T* dst = dstRet; //local stack copy | |||||
| #if _REGEN | |||||
| #region Compute | |||||
| switch (dtype.as_numpy_dtype().GetTypeCode()) | |||||
| { | |||||
| %foreach supported_dtypes,supported_dtypes_lowercase% | |||||
| case NPTypeCode.#1: new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| % | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #else | |||||
| #region Compute | |||||
| switch (dtype.as_numpy_dtype().GetTypeCode()) | |||||
| { | |||||
| case NPTypeCode.Boolean: new UnmanagedMemoryBlock<bool>((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Byte: new UnmanagedMemoryBlock<byte>((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Int16: new UnmanagedMemoryBlock<short>((short*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.UInt16: new UnmanagedMemoryBlock<ushort>((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Int32: new UnmanagedMemoryBlock<int>((int*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.UInt32: new UnmanagedMemoryBlock<uint>((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Int64: new UnmanagedMemoryBlock<long>((long*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.UInt64: new UnmanagedMemoryBlock<ulong>((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Char: new UnmanagedMemoryBlock<char>((char*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Double: new UnmanagedMemoryBlock<double>((double*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Single: new UnmanagedMemoryBlock<float>((float*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.String: throw new NotSupportedException("Unable to convert from string to other dtypes"); //TODO! this should call Converts.To<T> | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #endif | |||||
| } | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| } | } | ||||
| /// <summary> | |||||
| /// Copies the memory of current buffer onto newly allocated array. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| [Obsolete("Please use set_shape(TensorShape shape) instead.", false)] | |||||
| public byte[] Data() | public byte[] Data() | ||||
| { | { | ||||
| var data = new byte[bytesize]; | |||||
| Marshal.Copy(buffer, data, 0, (int)bytesize); | |||||
| return data; | |||||
| return BufferToArray(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Copies the memory of current buffer onto newly allocated array. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public byte[] BufferToArray() | |||||
| { | |||||
| unsafe | |||||
| { | |||||
| // ReSharper disable once LocalVariableHidesMember | |||||
| var bytesize = (long) this.bytesize; | |||||
| var data = new byte[bytesize]; | |||||
| fixed (byte* dst = data) | |||||
| System.Buffer.MemoryCopy(buffer.ToPointer(), dst, bytesize, bytesize); | |||||
| return data; | |||||
| } | |||||
| } | } | ||||
| /// <summary> | |||||
| /// Extracts string array from current Tensor. | |||||
| /// </summary> | |||||
| /// <exception cref="InvalidOperationException">When <see cref="dtype"/> != TF_DataType.TF_STRING</exception> | |||||
| public unsafe string[] StringData() | public unsafe string[] StringData() | ||||
| { | { | ||||
| if (dtype != TF_DataType.TF_STRING) | |||||
| throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})"); | |||||
| // | // | ||||
| // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. | // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. | ||||
| // [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes] | // [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes] | ||||
| @@ -199,19 +386,19 @@ namespace Tensorflow | |||||
| var buffer = new byte[size][]; | var buffer = new byte[size][]; | ||||
| var src = c_api.TF_TensorData(_handle); | var src = c_api.TF_TensorData(_handle); | ||||
| var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize); | |||||
| src += (int)(size * 8); | |||||
| var srcLen = (IntPtr) (src.ToInt64() + (long) bytesize); | |||||
| src += (int) (size * 8); | |||||
| for (int i = 0; i < buffer.Length; i++) | for (int i = 0; i < buffer.Length; i++) | ||||
| { | { | ||||
| using (var status = new Status()) | using (var status = new Status()) | ||||
| { | { | ||||
| IntPtr dst = IntPtr.Zero; | IntPtr dst = IntPtr.Zero; | ||||
| UIntPtr dstLen = UIntPtr.Zero; | UIntPtr dstLen = UIntPtr.Zero; | ||||
| var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status); | |||||
| var read = c_api.TF_StringDecode((byte*) src, (UIntPtr) (srcLen.ToInt64() - src.ToInt64()), (byte**) &dst, &dstLen, status); | |||||
| status.Check(true); | status.Check(true); | ||||
| buffer[i] = new byte[(int)dstLen]; | |||||
| buffer[i] = new byte[(int) dstLen]; | |||||
| Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | ||||
| src += (int)read; | |||||
| src += (int) read; | |||||
| } | } | ||||
| } | } | ||||
| @@ -229,51 +416,29 @@ namespace Tensorflow | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// Evaluates this tensor in a `Session`. | |||||
| /// Evaluates this tensor in a `Session`. | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param> | /// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param> | ||||
| /// <param name="session">The `Session` to be used to evaluate this tensor.</param> | |||||
| /// <returns></returns> | |||||
| /// <returns>A <see cref="NumSharp"/> array corresponding to the value of this tensor.</returns> | |||||
| public NDArray eval(params FeedItem[] feed_dict) | public NDArray eval(params FeedItem[] feed_dict) | ||||
| { | { | ||||
| return ops._eval_using_default_session(this, feed_dict, graph); | return ops._eval_using_default_session(this, feed_dict, graph); | ||||
| } | } | ||||
| public NDArray eval(Session session, FeedItem[] feed_dict = null) | |||||
| /// <summary> | |||||
| /// Evaluates this tensor in a `Session`. | |||||
| /// </summary> | |||||
| /// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param> | |||||
| /// <param name="session">The `Session` to be used to evaluate this tensor.</param> | |||||
| /// <returns>A <see cref="NumSharp"/> array corresponding to the value of this tensor.</returns> | |||||
| public NDArray eval(Session session, params FeedItem[] feed_dict) | |||||
| { | { | ||||
| return ops._eval_using_default_session(this, feed_dict, graph, session); | return ops._eval_using_default_session(this, feed_dict, graph, session); | ||||
| } | } | ||||
| public TF_DataType ToTFDataType(Type type) | |||||
| { | |||||
| switch (type.Name) | |||||
| { | |||||
| case "Char": | |||||
| return TF_DataType.TF_UINT8; | |||||
| case "Int16": | |||||
| return TF_DataType.TF_INT16; | |||||
| case "Int32": | |||||
| return TF_DataType.TF_INT32; | |||||
| case "Int64": | |||||
| return TF_DataType.TF_INT64; | |||||
| case "Single": | |||||
| return TF_DataType.TF_FLOAT; | |||||
| case "Double": | |||||
| return TF_DataType.TF_DOUBLE; | |||||
| case "Byte": | |||||
| return TF_DataType.TF_UINT8; | |||||
| case "String": | |||||
| return TF_DataType.TF_STRING; | |||||
| case "Boolean": | |||||
| return TF_DataType.TF_BOOL; | |||||
| default: | |||||
| throw new NotImplementedException("ToTFDataType error"); | |||||
| } | |||||
| } | |||||
| public Tensor slice(Slice slice) | public Tensor slice(Slice slice) | ||||
| { | { | ||||
| var slice_spec = new int[] { slice.Start.Value }; | |||||
| var slice_spec = new int[] {slice.Start.Value}; | |||||
| var begin = new List<int>(); | var begin = new List<int>(); | ||||
| var end = new List<int>(); | var end = new List<int>(); | ||||
| var strides = new List<int>(); | var strides = new List<int>(); | ||||
| @@ -289,26 +454,26 @@ namespace Tensorflow | |||||
| if (slice.Stop.HasValue) | if (slice.Stop.HasValue) | ||||
| { | { | ||||
| end.Add(slice.Stop.Value); | end.Add(slice.Stop.Value); | ||||
| } | |||||
| else | |||||
| } else | |||||
| { | { | ||||
| end.Add(0); | end.Add(0); | ||||
| end_mask |= (1 << index); | end_mask |= (1 << index); | ||||
| } | } | ||||
| strides.Add(slice.Step); | strides.Add(slice.Step); | ||||
| index += 1; | index += 1; | ||||
| } | } | ||||
| return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => | |||||
| return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope => | |||||
| { | { | ||||
| string name = scope; | string name = scope; | ||||
| if (begin != null) | if (begin != null) | ||||
| { | { | ||||
| var (packed_begin, packed_end, packed_strides) = | var (packed_begin, packed_end, packed_strides) = | ||||
| (array_ops.stack(begin.ToArray()), | (array_ops.stack(begin.ToArray()), | ||||
| array_ops.stack(end.ToArray()), | |||||
| array_ops.stack(strides.ToArray())); | |||||
| array_ops.stack(end.ToArray()), | |||||
| array_ops.stack(strides.ToArray())); | |||||
| return gen_array_ops.strided_slice( | return gen_array_ops.strided_slice( | ||||
| this, | this, | ||||
| @@ -320,7 +485,6 @@ namespace Tensorflow | |||||
| shrink_axis_mask: shrink_axis_mask, | shrink_axis_mask: shrink_axis_mask, | ||||
| new_axis_mask: new_axis_mask, | new_axis_mask: new_axis_mask, | ||||
| ellipsis_mask: ellipsis_mask, | ellipsis_mask: ellipsis_mask, | ||||
| name: name); | name: name); | ||||
| } | } | ||||
| @@ -330,7 +494,7 @@ namespace Tensorflow | |||||
| public Tensor slice(int start) | public Tensor slice(int start) | ||||
| { | { | ||||
| var slice_spec = new int[] { start }; | |||||
| var slice_spec = new int[] {start}; | |||||
| var begin = new List<int>(); | var begin = new List<int>(); | ||||
| var end = new List<int>(); | var end = new List<int>(); | ||||
| var strides = new List<int>(); | var strides = new List<int>(); | ||||
| @@ -349,15 +513,15 @@ namespace Tensorflow | |||||
| index += 1; | index += 1; | ||||
| } | } | ||||
| return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => | |||||
| return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope => | |||||
| { | { | ||||
| string name = scope; | string name = scope; | ||||
| if (begin != null) | if (begin != null) | ||||
| { | { | ||||
| var (packed_begin, packed_end, packed_strides) = | var (packed_begin, packed_end, packed_strides) = | ||||
| (array_ops.stack(begin.ToArray()), | (array_ops.stack(begin.ToArray()), | ||||
| array_ops.stack(end.ToArray()), | |||||
| array_ops.stack(strides.ToArray())); | |||||
| array_ops.stack(end.ToArray()), | |||||
| array_ops.stack(strides.ToArray())); | |||||
| return gen_array_ops.strided_slice( | return gen_array_ops.strided_slice( | ||||
| this, | this, | ||||
| @@ -369,7 +533,6 @@ namespace Tensorflow | |||||
| shrink_axis_mask: shrink_axis_mask, | shrink_axis_mask: shrink_axis_mask, | ||||
| new_axis_mask: new_axis_mask, | new_axis_mask: new_axis_mask, | ||||
| ellipsis_mask: ellipsis_mask, | ellipsis_mask: ellipsis_mask, | ||||
| name: name); | name: name); | ||||
| } | } | ||||
| @@ -392,29 +555,13 @@ namespace Tensorflow | |||||
| return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; | return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; | ||||
| } | } | ||||
| protected override void DisposeManagedState() | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| { | { | ||||
| c_api.TF_DeleteTensor(handle); | |||||
| } | } | ||||
| protected override void DisposeUnManagedState(IntPtr handle) | |||||
| { | |||||
| if(handle != IntPtr.Zero) | |||||
| { | |||||
| c_api.TF_DeleteTensor(handle); | |||||
| } | |||||
| } | |||||
| public bool IsDisposed | |||||
| { | |||||
| get | |||||
| { | |||||
| lock (this) | |||||
| { | |||||
| return _handle == IntPtr.Zero; | |||||
| } | |||||
| } | |||||
| } | |||||
| public bool IsDisposed => _disposed; | |||||
| public int tensor_int_val { get; set; } | public int tensor_int_val { get; set; } | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -1,35 +1,84 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using System.Diagnostics.CodeAnalysis; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.CompilerServices; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Represents the shape of a `Tensor`. | |||||
| /// Represents the shape of a `Tensor`. | |||||
| /// </summary> | /// </summary> | ||||
| /// <remarks>https://www.tensorflow.org/api_docs/python/tf/TensorShape</remarks> | |||||
| public class TensorShape | public class TensorShape | ||||
| { | { | ||||
| private Shape shape; | |||||
| private readonly Shape shape; | |||||
| /// <summary> | |||||
| /// Returns a list of Dimensions, or None if the shape is unspecified. | |||||
| /// </summary> | |||||
| public int[] dims => shape.Dimensions; | public int[] dims => shape.Dimensions; | ||||
| /// <summary> | |||||
| /// Returns the rank of this shape. | |||||
| /// </summary> | |||||
| public int ndim => shape.NDim; | public int ndim => shape.NDim; | ||||
| /// <summary> | |||||
| /// Returns the rank of this shape. | |||||
| /// </summary> | |||||
| public int rank => shape.NDim; | |||||
| /// <summary> | |||||
| /// Returns the size this shape represents. | |||||
| /// </summary> | |||||
| public int size => shape.Size; | public int size => shape.Size; | ||||
| public TensorShape(TensorShapeProto proto) | public TensorShape(TensorShapeProto proto) | ||||
| { | { | ||||
| if (proto.UnknownRank) return; | if (proto.UnknownRank) return; | ||||
| switch (proto.Dim.Count) | |||||
| { | |||||
| case 0: shape = new Shape(new int[0]); break; | |||||
| case 1: shape = Shape.Vector((int) proto.Dim[0].Size); break; | |||||
| case 2: shape = Shape.Matrix((int) proto.Dim[0].Size, (int) proto.Dim[1].Size); break; | |||||
| default: | |||||
| var protodims = proto.Dim; | |||||
| var len = protodims.Count; | |||||
| var dims = new int[len]; | |||||
| for (int i = 0; i < len; i++) | |||||
| dims[i] = (int) protodims[i].Size; | |||||
| shape.reshape(proto.Dim.Select(x => (int)x.Size).ToArray()); | |||||
| shape = new Shape(dims); break; | |||||
| } | |||||
| } | } | ||||
| public TensorShape(params int[] dims) | public TensorShape(params int[] dims) | ||||
| { | { | ||||
| shape = new Shape(dims); | |||||
| switch (dims.Length) | |||||
| { | |||||
| case 0: shape = new Shape(new int[0]); break; | |||||
| case 1: shape = Shape.Vector((int) dims[0]); break; | |||||
| case 2: shape = Shape.Matrix(dims[0], dims[1]); break; | |||||
| default: shape = new Shape(dims); break; | |||||
| } | |||||
| } | } | ||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="slice"></param> | |||||
| /// <returns></returns> | |||||
| /// <exception cref="ArgumentException">When <see cref="Slice"/> is not an Index.</exception> | |||||
| [SuppressMessage("ReSharper", "PossibleInvalidOperationException")] | |||||
| public TensorShape this[Slice slice] | public TensorShape this[Slice slice] | ||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| if (slice.Start.HasValue == false || slice.Length.HasValue == false) | |||||
| throw new ArgumentException("Slice must has Start and Length."); | |||||
| return new TensorShape(dims.Skip(slice.Start.Value) | return new TensorShape(dims.Skip(slice.Start.Value) | ||||
| .Take(slice.Length.Value) | .Take(slice.Length.Value) | ||||
| .ToArray()); | .ToArray()); | ||||
| @@ -37,7 +86,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// Returns True iff `self` is fully defined in every dimension. | |||||
| /// Returns True iff `self` is fully defined in every dimension. | |||||
| /// </summary> | /// </summary> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public bool is_fully_defined() | public bool is_fully_defined() | ||||
| @@ -50,6 +99,7 @@ namespace Tensorflow | |||||
| throw new NotImplementedException("TensorShape is_compatible_with"); | throw new NotImplementedException("TensorShape is_compatible_with"); | ||||
| } | } | ||||
| [SuppressMessage("ReSharper", "ParameterHidesMember")] | |||||
| public TensorShape with_rank_at_least(int rank) | public TensorShape with_rank_at_least(int rank) | ||||
| { | { | ||||
| if (rank != ndim) | if (rank != ndim) | ||||
| @@ -59,35 +109,68 @@ namespace Tensorflow | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// Returns the concatenation of the dimension in `self` and `other`. | |||||
| /// Returns the concatenation of the dimension in `self` and `other`. | |||||
| /// </summary> | |||||
| /// <param name="other"></param> | |||||
| /// <returns></returns> | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| public TensorShape concatenate(int[] other) | |||||
| { | |||||
| return concatenate(new TensorShape(other)); | |||||
| } | |||||
| /// <summary> | |||||
| /// Returns the concatenation of the dimension in `self` and `other`. | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="other"></param> | /// <param name="other"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public TensorShape concatenate(int[] other_) | |||||
| public TensorShape concatenate(TensorShape other) | |||||
| { | { | ||||
| var other = new TensorShape(other_); | |||||
| var otherShape = other; | |||||
| if (ndim < 0 || other.ndim < 0) | |||||
| if (ndim < 0 || otherShape.ndim < 0) | |||||
| return new TensorShape(); | return new TensorShape(); | ||||
| else | else | ||||
| { | { | ||||
| var concatenate_dims = new int[ndim + other.ndim]; | |||||
| var concatenate_dims = new int[ndim + otherShape.ndim]; | |||||
| for (int i = 0; i < ndim; i++) | for (int i = 0; i < ndim; i++) | ||||
| concatenate_dims[i] = dims[i]; | concatenate_dims[i] = dims[i]; | ||||
| for (int i = 0; i < other.ndim; i++) | |||||
| concatenate_dims[ndim + i] = other.dims[i]; | |||||
| for (int i = 0; i < otherShape.ndim; i++) | |||||
| concatenate_dims[ndim + i] = otherShape.dims[i]; | |||||
| return new TensorShape(concatenate_dims); | return new TensorShape(concatenate_dims); | ||||
| } | } | ||||
| } | } | ||||
| public static implicit operator TensorShape(Shape shape) => new TensorShape(shape.Dimensions); | |||||
| public static implicit operator Shape(TensorShape shape) => new Shape(shape.dims); | |||||
| public override string ToString() | |||||
| { | |||||
| return shape.ToString(); | |||||
| } | |||||
| public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone()); | |||||
| public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone()); | |||||
| public static implicit operator int[](TensorShape shape) => (int[])shape.dims.Clone(); //we clone to avoid any changes | |||||
| public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); | public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); | ||||
| public static implicit operator int[](TensorShape shape) => shape.dims; | |||||
| public static explicit operator int(TensorShape shape) => shape.size; | |||||
| public static explicit operator TensorShape(int dim) => new TensorShape(dim); | |||||
| public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0); | |||||
| public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); | public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); | ||||
| public static explicit operator (int, int, int)(TensorShape shape) => shape.dims.Length == 3 ? (shape.dims[0], shape.dims[1], shape.dims[2]) : (0, 0, 0); | |||||
| public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); | public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); | ||||
| public static explicit operator (int, int, int, int)(TensorShape shape) => shape.dims.Length == 4 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3]) : (0, 0, 0, 0); | |||||
| public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); | public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); | ||||
| public static explicit operator (int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 5 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4]) : (0, 0, 0, 0, 0); | |||||
| public static implicit operator TensorShape((int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5); | |||||
| public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0); | |||||
| public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6); | |||||
| } | } | ||||
| } | } | ||||
| @@ -15,6 +15,8 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | using System; | ||||
| using System.Numerics; | |||||
| using NumSharp.Backends; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -23,35 +25,100 @@ namespace Tensorflow | |||||
| public static TF_DataType int8 = TF_DataType.TF_INT8; | public static TF_DataType int8 = TF_DataType.TF_INT8; | ||||
| public static TF_DataType int32 = TF_DataType.TF_INT32; | public static TF_DataType int32 = TF_DataType.TF_INT32; | ||||
| public static TF_DataType int64 = TF_DataType.TF_INT64; | public static TF_DataType int64 = TF_DataType.TF_INT64; | ||||
| public static TF_DataType uint8 = TF_DataType.TF_UINT8; | |||||
| public static TF_DataType uint32 = TF_DataType.TF_UINT32; | |||||
| public static TF_DataType uint64 = TF_DataType.TF_UINT64; | |||||
| public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? | public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? | ||||
| public static TF_DataType float16 = TF_DataType.TF_HALF; | public static TF_DataType float16 = TF_DataType.TF_HALF; | ||||
| public static TF_DataType float64 = TF_DataType.TF_DOUBLE; | public static TF_DataType float64 = TF_DataType.TF_DOUBLE; | ||||
| public static Type as_numpy_datatype(this TF_DataType type) | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="type"></param> | |||||
| /// <returns><see cref="System.Type"/> equivalent to <paramref name="type"/>, if none exists, returns null.</returns> | |||||
| public static Type as_numpy_dtype(this TF_DataType type) | |||||
| { | { | ||||
| switch (type) | switch (type) | ||||
| { | { | ||||
| case TF_DataType.TF_BOOL: | case TF_DataType.TF_BOOL: | ||||
| return typeof(bool); | return typeof(bool); | ||||
| case TF_DataType.TF_UINT8: | |||||
| return typeof(byte); | |||||
| case TF_DataType.TF_INT64: | case TF_DataType.TF_INT64: | ||||
| return typeof(long); | return typeof(long); | ||||
| case TF_DataType.TF_UINT64: | |||||
| return typeof(ulong); | |||||
| case TF_DataType.TF_INT32: | case TF_DataType.TF_INT32: | ||||
| return typeof(int); | return typeof(int); | ||||
| case TF_DataType.TF_UINT32: | |||||
| return typeof(uint); | |||||
| case TF_DataType.TF_INT16: | case TF_DataType.TF_INT16: | ||||
| return typeof(short); | return typeof(short); | ||||
| case TF_DataType.TF_UINT16: | |||||
| return typeof(ushort); | |||||
| case TF_DataType.TF_FLOAT: | case TF_DataType.TF_FLOAT: | ||||
| return typeof(float); | return typeof(float); | ||||
| case TF_DataType.TF_DOUBLE: | case TF_DataType.TF_DOUBLE: | ||||
| return typeof(double); | return typeof(double); | ||||
| case TF_DataType.TF_STRING: | case TF_DataType.TF_STRING: | ||||
| return typeof(string); | return typeof(string); | ||||
| case TF_DataType.TF_COMPLEX128: | |||||
| case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX | |||||
| return typeof(Complex); | |||||
| default: | default: | ||||
| return null; | return null; | ||||
| } | } | ||||
| } | } | ||||
| // "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex" | |||||
| public static TF_DataType as_dtype(Type type, TF_DataType? dtype = null) | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="type"></param> | |||||
| /// <returns></returns> | |||||
| /// <exception cref="ArgumentException">When <paramref name="type"/> has no equivalent <see cref="NPTypeCode"/></exception> | |||||
| public static NPTypeCode as_numpy_typecode(this TF_DataType type) | |||||
| { | |||||
| switch (type) | |||||
| { | |||||
| case TF_DataType.TF_BOOL: | |||||
| return NPTypeCode.Boolean; | |||||
| case TF_DataType.TF_UINT8: | |||||
| return NPTypeCode.Byte; | |||||
| case TF_DataType.TF_INT64: | |||||
| return NPTypeCode.Int64; | |||||
| case TF_DataType.TF_INT32: | |||||
| return NPTypeCode.Int32; | |||||
| case TF_DataType.TF_INT16: | |||||
| return NPTypeCode.Int16; | |||||
| case TF_DataType.TF_UINT64: | |||||
| return NPTypeCode.UInt64; | |||||
| case TF_DataType.TF_UINT32: | |||||
| return NPTypeCode.UInt32; | |||||
| case TF_DataType.TF_UINT16: | |||||
| return NPTypeCode.UInt16; | |||||
| case TF_DataType.TF_FLOAT: | |||||
| return NPTypeCode.Single; | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| return NPTypeCode.Double; | |||||
| case TF_DataType.TF_STRING: | |||||
| return NPTypeCode.String; | |||||
| case TF_DataType.TF_COMPLEX128: | |||||
| case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX | |||||
| return NPTypeCode.Complex; | |||||
| default: | |||||
| throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="type"></param> | |||||
| /// <param name="dtype"></param> | |||||
| /// <returns></returns> | |||||
| /// <exception cref="ArgumentException">When <paramref name="type"/> has no equivalent <see cref="TF_DataType"/></exception> | |||||
| public static TF_DataType as_dtype(this Type type, TF_DataType? dtype = null) | |||||
| { | { | ||||
| switch (type.Name) | switch (type.Name) | ||||
| { | { | ||||
| @@ -98,7 +165,7 @@ namespace Tensorflow | |||||
| dtype = TF_DataType.TF_BOOL; | dtype = TF_DataType.TF_BOOL; | ||||
| break; | break; | ||||
| default: | default: | ||||
| throw new Exception("as_dtype Not Implemented"); | |||||
| throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); | |||||
| } | } | ||||
| return dtype.Value; | return dtype.Value; | ||||
| @@ -106,16 +173,7 @@ namespace Tensorflow | |||||
| public static DataType as_datatype_enum(this TF_DataType type) | public static DataType as_datatype_enum(this TF_DataType type) | ||||
| { | { | ||||
| DataType dtype = DataType.DtInvalid; | |||||
| switch (type) | |||||
| { | |||||
| default: | |||||
| Enum.TryParse(((int)type).ToString(), out dtype); | |||||
| break; | |||||
| } | |||||
| return dtype; | |||||
| return Enum.TryParse(((int) type).ToString(), out DataType dtype) ? dtype : DataType.DtInvalid; | |||||
| } | } | ||||
| public static TF_DataType as_base_dtype(this TF_DataType type) | public static TF_DataType as_base_dtype(this TF_DataType type) | ||||
| @@ -132,7 +190,7 @@ namespace Tensorflow | |||||
| public static Type as_numpy_dtype(this DataType type) | public static Type as_numpy_dtype(this DataType type) | ||||
| { | { | ||||
| return type.as_tf_dtype().as_numpy_datatype(); | |||||
| return type.as_tf_dtype().as_numpy_dtype(); | |||||
| } | } | ||||
| public static DataType as_base_dtype(this DataType type) | public static DataType as_base_dtype(this DataType type) | ||||
| @@ -144,16 +202,7 @@ namespace Tensorflow | |||||
| public static TF_DataType as_tf_dtype(this DataType type) | public static TF_DataType as_tf_dtype(this DataType type) | ||||
| { | { | ||||
| TF_DataType dtype = TF_DataType.DtInvalid; | |||||
| switch (type) | |||||
| { | |||||
| default: | |||||
| Enum.TryParse(((int)type).ToString(), out dtype); | |||||
| break; | |||||
| } | |||||
| return dtype; | |||||
| return Enum.TryParse(((int) type).ToString(), out TF_DataType dtype) ? dtype : TF_DataType.DtInvalid; | |||||
| } | } | ||||
| public static TF_DataType as_ref(this TF_DataType type) | public static TF_DataType as_ref(this TF_DataType type) | ||||
| @@ -17,6 +17,7 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using System.Linq; | using System.Linq; | ||||
| using NumSharp.Utilities; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -82,6 +83,12 @@ namespace Tensorflow | |||||
| throw new NotImplementedException("MakeNdarray"); | throw new NotImplementedException("MakeNdarray"); | ||||
| } | } | ||||
| private static readonly TF_DataType[] quantized_types = new TF_DataType[] | |||||
| { | |||||
| TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16, | |||||
| TF_DataType.TF_QINT32 | |||||
| }; | |||||
| /// <summary> | /// <summary> | ||||
| /// Create a TensorProto. | /// Create a TensorProto. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -98,18 +105,9 @@ namespace Tensorflow | |||||
| if (values is TensorProto tp) | if (values is TensorProto tp) | ||||
| return tp; | return tp; | ||||
| if (dtype != TF_DataType.DtInvalid) | |||||
| ; | |||||
| bool is_quantized = new TF_DataType[] | |||||
| { | |||||
| TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16, | |||||
| TF_DataType.TF_QINT32 | |||||
| }.Contains(dtype); | |||||
| // We first convert value to a numpy array or scalar. | // We first convert value to a numpy array or scalar. | ||||
| NDArray nparray = null; | NDArray nparray = null; | ||||
| var np_dt = dtype.as_numpy_datatype(); | |||||
| var np_dt = dtype.as_numpy_dtype(); | |||||
| if (values is NDArray nd) | if (values is NDArray nd) | ||||
| { | { | ||||
| @@ -188,37 +186,37 @@ namespace Tensorflow | |||||
| if (values.GetType().IsArray) | if (values.GetType().IsArray) | ||||
| nparray = np.array((int[])values, np_dt); | nparray = np.array((int[])values, np_dt); | ||||
| else | else | ||||
| nparray = Convert.ToInt32(values); | |||||
| nparray = Converts.ToInt32(values); | |||||
| break; | break; | ||||
| case "Int64": | case "Int64": | ||||
| if (values.GetType().IsArray) | if (values.GetType().IsArray) | ||||
| nparray = np.array((int[])values, np_dt); | nparray = np.array((int[])values, np_dt); | ||||
| else | else | ||||
| nparray = Convert.ToInt64(values); | |||||
| nparray = Converts.ToInt64(values); | |||||
| break; | break; | ||||
| case "Single": | case "Single": | ||||
| if (values.GetType().IsArray) | if (values.GetType().IsArray) | ||||
| nparray = np.array((float[])values, np_dt); | nparray = np.array((float[])values, np_dt); | ||||
| else | else | ||||
| nparray = Convert.ToSingle(values); | |||||
| nparray = Converts.ToSingle(values); | |||||
| break; | break; | ||||
| case "Double": | case "Double": | ||||
| if (values.GetType().IsArray) | if (values.GetType().IsArray) | ||||
| nparray = np.array((double[])values, np_dt); | nparray = np.array((double[])values, np_dt); | ||||
| else | else | ||||
| nparray = Convert.ToDouble(values); | |||||
| nparray = Converts.ToDouble(values); | |||||
| break; | break; | ||||
| case "String": | case "String": | ||||
| if (values.GetType().IsArray) | if (values.GetType().IsArray) | ||||
| nparray = np.array((string[])values, np_dt); | nparray = np.array((string[])values, np_dt); | ||||
| else | else | ||||
| nparray = NDArray.FromString(Convert.ToString(values)); | |||||
| nparray = NDArray.FromString(Converts.ToString(values)); | |||||
| break; | break; | ||||
| case "Boolean": | case "Boolean": | ||||
| if (values.GetType().IsArray) | if (values.GetType().IsArray) | ||||
| nparray = np.array((bool[])values, np_dt); | nparray = np.array((bool[])values, np_dt); | ||||
| else | else | ||||
| nparray = Convert.ToBoolean(values); | |||||
| nparray = Converts.ToBoolean(values); | |||||
| break; | break; | ||||
| default: | default: | ||||
| throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented"); | throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented"); | ||||
| @@ -226,13 +224,13 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| var numpy_dtype = dtypes.as_dtype(nparray.dtype, dtype: dtype); | |||||
| var numpy_dtype = nparray.dtype.as_dtype(dtype: dtype); | |||||
| if (numpy_dtype == TF_DataType.DtInvalid) | if (numpy_dtype == TF_DataType.DtInvalid) | ||||
| throw new TypeError($"Unrecognized data type: {nparray.dtype}"); | throw new TypeError($"Unrecognized data type: {nparray.dtype}"); | ||||
| // If dtype was specified and is a quantized type, we convert | // If dtype was specified and is a quantized type, we convert | ||||
| // numpy_dtype back into the quantized version. | // numpy_dtype back into the quantized version. | ||||
| if (is_quantized) | |||||
| if (quantized_types.Contains(dtype)) | |||||
| numpy_dtype = dtype; | numpy_dtype = dtype; | ||||
| bool is_same_size = false; | bool is_same_size = false; | ||||
| @@ -0,0 +1,52 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Train | |||||
| { | |||||
| public class ExponentialMovingAverage | |||||
| { | |||||
| float _decay; | |||||
| int? _num_updates; | |||||
| bool _zero_debias; | |||||
| string _name; | |||||
| public string name => _name; | |||||
| List<VariableV1> _averages; | |||||
| public ExponentialMovingAverage(float decay, int? num_updates = null, bool zero_debias = false, | |||||
| string name = "ExponentialMovingAverage") | |||||
| { | |||||
| _decay = decay; | |||||
| _num_updates = num_updates; | |||||
| _zero_debias = zero_debias; | |||||
| _name = name; | |||||
| _averages = new List<VariableV1>(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Maintains moving averages of variables. | |||||
| /// </summary> | |||||
| /// <param name="var_list"></param> | |||||
| /// <returns></returns> | |||||
| public Operation apply(RefVariable[] var_list = null) | |||||
| { | |||||
| if (var_list == null) | |||||
| var_list = variables.trainable_variables() as RefVariable[]; | |||||
| foreach(var var in var_list) | |||||
| { | |||||
| if (!_averages.Contains(var)) | |||||
| { | |||||
| ops.init_scope(); | |||||
| var slot = new SlotCreator(); | |||||
| var.initialized_value(); | |||||
| // var avg = slot.create_zeros_slot | |||||
| } | |||||
| } | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -198,7 +198,7 @@ namespace Tensorflow | |||||
| if (!tf.context.executing_eagerly()) | if (!tf.context.executing_eagerly()) | ||||
| { | { | ||||
| var train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) as List<ITensorOrOperation>; | |||||
| var train_op = ops.get_collection_ref(tf.GraphKeys.TRAIN_OP) as List<ITensorOrOperation>; | |||||
| if (train_op != null && train_op.Contains(apply_updates)) | if (train_op != null && train_op.Contains(apply_updates)) | ||||
| train_op.Add(apply_updates); | train_op.Add(apply_updates); | ||||
| } | } | ||||
| @@ -359,7 +359,7 @@ namespace Tensorflow | |||||
| var tmp = variables.trainable_variables(); | var tmp = variables.trainable_variables(); | ||||
| var vars = ops.get_collection<RefVariable>(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); | |||||
| var vars = ops.get_collection<RefVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); | |||||
| switch (tmp) | switch (tmp) | ||||
| { | { | ||||
| case List<RefVariable> values: | case List<RefVariable> values: | ||||
| @@ -370,7 +370,7 @@ namespace Tensorflow | |||||
| break; | break; | ||||
| } | } | ||||
| var_list = var_list.Concat(ops.get_collection<RefVariable>(ops.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); | |||||
| var_list = var_list.Concat(ops.get_collection<RefVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); | |||||
| var processors = var_list.Select(v => optimizer._get_processor(v)).ToList(); | var processors = var_list.Select(v => optimizer._get_processor(v)).ToList(); | ||||
| var var_refs = processors.Select(x => x.target()).ToArray(); | var var_refs = processors.Select(x => x.target()).ToArray(); | ||||
| @@ -0,0 +1,94 @@ | |||||
| using System; | |||||
| using System.IO; | |||||
| using System.Runtime.CompilerServices; | |||||
| using System.Runtime.InteropServices; | |||||
| using NumSharp.Backends.Unmanaged; | |||||
| namespace Tensorflow.Util | |||||
| { | |||||
| public static class UnmanagedExtensions | |||||
| { | |||||
| //internally UnmanagedMemoryStream can't construct with null address. | |||||
| private static readonly unsafe byte* _empty = (byte*) Marshal.AllocHGlobal(1); | |||||
| /// <summary> | |||||
| /// Creates a memory stream based on given <paramref name="block"/>. | |||||
| /// </summary> | |||||
| /// <param name="block">The block to stream. Can be default/null.</param> | |||||
| /// <remarks>There is no need to dispose the returned <see cref="UnmanagedMemoryStream"/></remarks> | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| public static UnmanagedMemoryStream Stream(this UnmanagedMemoryBlock<byte> block) | |||||
| { | |||||
| unsafe | |||||
| { | |||||
| if (block.Address == null) | |||||
| return new UnmanagedMemoryStream(_empty, 0); | |||||
| return new UnmanagedMemoryStream(block.Address, block.BytesCount); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Creates a memory stream based on given <paramref name="block"/>. | |||||
| /// </summary> | |||||
| /// <param name="block">The block to stream. Can be default/null.</param> | |||||
| /// <param name="offset">Offset from the start of the block.</param> | |||||
| /// <remarks>There is no need to dispose the returned <see cref="UnmanagedMemoryStream"/></remarks> | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| public static UnmanagedMemoryStream Stream(this UnmanagedMemoryBlock<byte> block, long offset) | |||||
| { | |||||
| if (block.BytesCount - offset <= 0) | |||||
| throw new ArgumentOutOfRangeException(nameof(offset)); | |||||
| unsafe | |||||
| { | |||||
| if (block.Address == null) | |||||
| return new UnmanagedMemoryStream(_empty, 0); | |||||
| return new UnmanagedMemoryStream(block.Address + offset, block.BytesCount - offset); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Creates a memory stream based on given <paramref name="address"/>. | |||||
| /// </summary> | |||||
| /// <param name="address">The block to stream. Can be IntPtr.Zero.</param> | |||||
| /// <param name="length">The length of the block in bytes.</param> | |||||
| /// <remarks>There is no need to dispose the returned <see cref="UnmanagedMemoryStream"/></remarks> | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| public static UnmanagedMemoryStream Stream(this IntPtr address, long length) | |||||
| { | |||||
| if (length <= 0) | |||||
| throw new ArgumentOutOfRangeException(nameof(length)); | |||||
| unsafe | |||||
| { | |||||
| if (address == IntPtr.Zero) | |||||
| return new UnmanagedMemoryStream(_empty, 0); | |||||
| // ReSharper disable once AssignNullToNotNullAttribute | |||||
| return new UnmanagedMemoryStream((byte*) address, length); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Creates a memory stream based on given <paramref name="address"/>. | |||||
| /// </summary> | |||||
| /// <param name="address">The block to stream. Can be IntPtr.Zero.</param> | |||||
| /// <param name="offset">Offset from the start of the block.</param> | |||||
| /// <param name="length">The length of the block in bytes.</param> | |||||
| /// <remarks>There is no need to dispose the returned <see cref="UnmanagedMemoryStream"/></remarks> | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| public static UnmanagedMemoryStream Stream(this IntPtr address, long offset, long length) | |||||
| { | |||||
| if (length <= 0) | |||||
| throw new ArgumentOutOfRangeException(nameof(length)); | |||||
| unsafe | |||||
| { | |||||
| if (address == IntPtr.Zero) | |||||
| return new UnmanagedMemoryStream(_empty, 0); | |||||
| return new UnmanagedMemoryStream((byte*) address + offset, length); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -121,7 +121,7 @@ namespace Tensorflow | |||||
| if(collections == null) | if(collections == null) | ||||
| { | { | ||||
| collections = new List<string> { ops.GraphKeys.GLOBAL_VARIABLES }; | |||||
| collections = new List<string> { tf.GraphKeys.GLOBAL_VARIABLES }; | |||||
| } | } | ||||
| // Store the graph key so optimizers know how to only retrieve variables from | // Store the graph key so optimizers know how to only retrieve variables from | ||||
| @@ -129,8 +129,8 @@ namespace Tensorflow | |||||
| _graph_key = ops.get_default_graph().graph_key; | _graph_key = ops.get_default_graph().graph_key; | ||||
| _trainable = trainable; | _trainable = trainable; | ||||
| if (trainable && !collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES)) | |||||
| collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES); | |||||
| if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) | |||||
| collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); | |||||
| ops.init_scope(); | ops.init_scope(); | ||||
| var values = init_from_fn ? new object[0] : new object[] { initial_value }; | var values = init_from_fn ? new object[0] : new object[] { initial_value }; | ||||
| @@ -158,7 +158,7 @@ namespace Tensorflow | |||||
| // Or get the initial value from a Tensor or Python object. | // Or get the initial value from a Tensor or Python object. | ||||
| else | else | ||||
| { | { | ||||
| _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); | |||||
| _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype); | |||||
| var shape = _initial_value.shape; | var shape = _initial_value.shape; | ||||
| dtype = _initial_value.dtype; | dtype = _initial_value.dtype; | ||||
| @@ -308,5 +308,28 @@ namespace Tensorflow | |||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Returns the value of this variable, read in the current context. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| private ITensorOrOperation read_value() | |||||
| { | |||||
| return array_ops.identity(_variable, name: "read"); | |||||
| } | |||||
| public Tensor is_variable_initialized(RefVariable variable) | |||||
| { | |||||
| return state_ops.is_variable_initialized(variable); | |||||
| } | |||||
| public Tensor initialized_value() | |||||
| { | |||||
| ops.init_scope(); | |||||
| throw new NotImplementedException(""); | |||||
| /*return control_flow_ops.cond(is_variable_initialized(this), | |||||
| read_value, | |||||
| () => initial_value);*/ | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -14,6 +14,7 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| @@ -145,5 +146,10 @@ namespace Tensorflow | |||||
| var _op = _op_def_lib._apply_op_helper("ScatterAdd", name: name, args: new { @ref, indices, updates, use_locking }); | var _op = _op_def_lib._apply_op_helper("ScatterAdd", name: name, args: new { @ref, indices, updates, use_locking }); | ||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| public static Tensor is_variable_initialized(RefVariable @ref, string name = null) | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -106,5 +106,13 @@ namespace Tensorflow | |||||
| throw new NotImplementedException("scatter_add"); | throw new NotImplementedException("scatter_add"); | ||||
| } | } | ||||
| public static Tensor is_variable_initialized(RefVariable @ref, string name = null) | |||||
| { | |||||
| if (@ref.dtype.is_ref_dtype()) | |||||
| return gen_state_ops.is_variable_initialized(@ref: @ref, name: name); | |||||
| throw new NotImplementedException(""); | |||||
| //return @ref.is_initialized(name: name); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -17,6 +17,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -28,7 +29,7 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static object trainable_variables() | public static object trainable_variables() | ||||
| { | { | ||||
| return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES); | |||||
| return ops.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -40,11 +41,11 @@ namespace Tensorflow | |||||
| { | { | ||||
| var all = new List<VariableV1>(); | var all = new List<VariableV1>(); | ||||
| var collection = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope); | |||||
| var collection = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope); | |||||
| if(collection != null) | if(collection != null) | ||||
| all.AddRange(collection as List<VariableV1>); | all.AddRange(collection as List<VariableV1>); | ||||
| collection = ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope); | |||||
| collection = ops.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS, scope); | |||||
| if (collection != null) | if (collection != null) | ||||
| all.AddRange(collection as List<VariableV1>); | all.AddRange(collection as List<VariableV1>); | ||||
| @@ -64,7 +65,7 @@ namespace Tensorflow | |||||
| /// <returns>A list of `Variable` objects.</returns> | /// <returns>A list of `Variable` objects.</returns> | ||||
| public static List<VariableV1> global_variables(string scope = null) | public static List<VariableV1> global_variables(string scope = null) | ||||
| { | { | ||||
| var result = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope); | |||||
| var result = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope); | |||||
| return result == null ? new List<VariableV1>() : result as List<VariableV1>; | return result == null ? new List<VariableV1>() : result as List<VariableV1>; | ||||
| } | } | ||||
| @@ -0,0 +1,40 @@ | |||||
| %all_dtypes = ["NDArray","Complex","Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single","String"] | |||||
| %all_dtypes_lowercase = ["NDArray","Complex","bool","byte","short","ushort","int","uint","long","ulong","char","double","float","string"] | |||||
| %supported_primitives = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single","String"] | |||||
| %supported_primitives_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float","string"] | |||||
| %supported_numericals = ["Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"] | |||||
| %supported_numericals_lowercase = ["byte","short","ushort","int","uint","long","ulong","char","double","float"] | |||||
| %supported_numericals_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] | |||||
| %supported_numericals_onevales = ["1","1","1","1","1u","1L","1UL",1,"1d","1f"] | |||||
| %supported_numericals_TF_DataType = ["TF_UINT8","TF_INT16","TF_UINT16","TF_INT32","TF_UINT32","TF_INT64","TF_UINT64","TF_STRING","TF_DOUBLE","TF_FLOAT"] | |||||
| %supported_numericals_TF_DataType_full = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_STRING","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] | |||||
| //this is the type we use in summerizing/reducting: | |||||
| %supported_numericals_accumulatingType = ["UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"] | |||||
| %supported_numericals_accumulatingType_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] | |||||
| %supported_numericals_signed = ["Int16","Int32","Int64","Double","Single"] | |||||
| %supported_numericals_signed_lowercase = ["short","int","long","double","float"] | |||||
| %supported_numericals_signed_defaultvals = ["0","0","0L","0d","0f"] | |||||
| %supported_numericals_signed_onevales = ["1","1","1L","1d","1f"] | |||||
| %supported_numericals_unsigned = ["Byte","UInt16","UInt32","UInt64","Char"] | |||||
| %supported_numericals_unsigned_lowercase = ["byte","ushort","uint","ulong","char"] | |||||
| %supported_numericals_unsigned_defaultvals = ["0","0","0U","0UL","'\0'"] | |||||
| %supported_numericals_unsigned_onevales = ["1","1","1U","1UL","'\1'"] | |||||
| %supported_dtypes = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"] | |||||
| %supported_dtypes_TF_DataType = ["TF_BOOL","TF_UINT8","TF_INT16","TF_UINT16","TF_INT32","TF_UINT32","TF_INT64","TF_UINT64","TF_STRING","TF_DOUBLE","TF_FLOAT"] | |||||
| %supported_dtypes_TF_DataType_full = ["TF_DataType.TF_BOOL","TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_STRING","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] | |||||
| %supported_dtypes_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float"] | |||||
| %supported_dtypes_defaultvals = [false,"0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] | |||||
| %supported_dtypes_onevales = [true,"1","1","1","1","1u","1L","1UL","'\1'","1d","1f"] | |||||
| %supported_dtypes_dtype = ["bool","uint8","int16","uint16","int32","uint32","int64","uint64","uint8","float64","float32"] | |||||
| //this is the type we use in summerizing/reducting: | |||||
| %supported_dtypes_accumulatingType = ["Int32","UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"] | |||||
| %supported_dtypes_accumulatingType_defaultvals = [false, "0","0","0","0u","0L","0UL","'\0'","0d","0f"] | |||||
| @@ -27,57 +27,113 @@ namespace Tensorflow | |||||
| /// specified, but it is also possible to pass an explicit list of | /// specified, but it is also possible to pass an explicit list of | ||||
| /// variables. | /// variables. | ||||
| /// </summary> | /// </summary> | ||||
| public static class GraphKeys | |||||
| public class GraphKeys | |||||
| { | { | ||||
| #region const | |||||
| /// <summary> | |||||
| /// the subset of `Variable` objects that will be trained by an optimizer. | |||||
| /// </summary> | |||||
| public const string TRAINABLE_VARIABLES_ = "trainable_variables"; | |||||
| /// <summary> | |||||
| /// Trainable resource-style variables. | |||||
| /// </summary> | |||||
| public const string TRAINABLE_RESOURCE_VARIABLES_ = "trainable_resource_variables"; | |||||
| /// <summary> | |||||
| /// Key for streaming model ports. | |||||
| /// </summary> | |||||
| public const string _STREAMING_MODEL_PORTS_ = "streaming_model_ports"; | |||||
| /// <summary> | |||||
| /// Key to collect losses | |||||
| /// </summary> | |||||
| public const string LOSSES_ = "losses"; | |||||
| /// <summary> | |||||
| /// Key to collect Variable objects that are global (shared across machines). | |||||
| /// Default collection for all variables, except local ones. | |||||
| /// </summary> | |||||
| public const string GLOBAL_VARIABLES_ = "variables"; | |||||
| public const string TRAIN_OP_ = "train_op"; | |||||
| public const string GLOBAL_STEP_ = "global_step"; | |||||
| public string[] _VARIABLE_COLLECTIONS_ = new string[] { "variables", "trainable_variables", "model_variables" }; | |||||
| /// <summary> | |||||
| /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | |||||
| /// </summary> | |||||
| public const string SAVEABLE_OBJECTS_ = "saveable_objects"; | |||||
| /// <summary> | |||||
| /// Key to collect update_ops | |||||
| /// </summary> | |||||
| public const string UPDATE_OPS_ = "update_ops"; | |||||
| // Key to collect summaries. | |||||
| public const string SUMMARIES_ = "summaries"; | |||||
| // Used to store v2 summary names. | |||||
| public const string _SUMMARY_COLLECTION_ = "_SUMMARY_V2"; | |||||
| // Key for control flow context. | |||||
| public const string COND_CONTEXT_ = "cond_context"; | |||||
| public const string WHILE_CONTEXT_ = "while_context"; | |||||
| #endregion | |||||
| /// <summary> | /// <summary> | ||||
| /// the subset of `Variable` objects that will be trained by an optimizer. | /// the subset of `Variable` objects that will be trained by an optimizer. | ||||
| /// </summary> | /// </summary> | ||||
| public static string TRAINABLE_VARIABLES = "trainable_variables"; | |||||
| public string TRAINABLE_VARIABLES => TRAINABLE_VARIABLES_; | |||||
| /// <summary> | /// <summary> | ||||
| /// Trainable resource-style variables. | /// Trainable resource-style variables. | ||||
| /// </summary> | /// </summary> | ||||
| public static string TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"; | |||||
| public string TRAINABLE_RESOURCE_VARIABLES => TRAINABLE_RESOURCE_VARIABLES_; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key for streaming model ports. | /// Key for streaming model ports. | ||||
| /// </summary> | /// </summary> | ||||
| public static string _STREAMING_MODEL_PORTS = "streaming_model_ports"; | |||||
| public string _STREAMING_MODEL_PORTS => _STREAMING_MODEL_PORTS_; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect losses | /// Key to collect losses | ||||
| /// </summary> | /// </summary> | ||||
| public const string LOSSES = "losses"; | |||||
| public string LOSSES => LOSSES_; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect Variable objects that are global (shared across machines). | /// Key to collect Variable objects that are global (shared across machines). | ||||
| /// Default collection for all variables, except local ones. | /// Default collection for all variables, except local ones. | ||||
| /// </summary> | /// </summary> | ||||
| public static string GLOBAL_VARIABLES = "variables"; | |||||
| public string GLOBAL_VARIABLES => GLOBAL_VARIABLES_; | |||||
| public static string TRAIN_OP = "train_op"; | |||||
| public string TRAIN_OP => TRAIN_OP_; | |||||
| public static string GLOBAL_STEP = GLOBAL_STEP = "global_step"; | |||||
| public string GLOBAL_STEP => GLOBAL_STEP_; | |||||
| public static string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables", "model_variables" }; | |||||
| public string[] _VARIABLE_COLLECTIONS => _VARIABLE_COLLECTIONS_; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | ||||
| /// </summary> | /// </summary> | ||||
| public static string SAVEABLE_OBJECTS = "saveable_objects"; | |||||
| public string SAVEABLE_OBJECTS => SAVEABLE_OBJECTS_; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect update_ops | /// Key to collect update_ops | ||||
| /// </summary> | /// </summary> | ||||
| public static string UPDATE_OPS = "update_ops"; | |||||
| public string UPDATE_OPS => UPDATE_OPS_; | |||||
| // Key to collect summaries. | // Key to collect summaries. | ||||
| public const string SUMMARIES = "summaries"; | |||||
| public string SUMMARIES => SUMMARIES_; | |||||
| // Used to store v2 summary names. | // Used to store v2 summary names. | ||||
| public static string _SUMMARY_COLLECTION = "_SUMMARY_V2"; | |||||
| public string _SUMMARY_COLLECTION => _SUMMARY_COLLECTION_; | |||||
| // Key for control flow context. | // Key for control flow context. | ||||
| public static string COND_CONTEXT = "cond_context"; | |||||
| public static string WHILE_CONTEXT = "while_context"; | |||||
| public string COND_CONTEXT => COND_CONTEXT_; | |||||
| public string WHILE_CONTEXT => WHILE_CONTEXT_; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -230,8 +230,8 @@ namespace Tensorflow | |||||
| // Add attrs | // Add attrs | ||||
| foreach (var attr in node_def.Attr) | foreach (var attr in node_def.Attr) | ||||
| { | { | ||||
| var bytes = attr.Value.ToByteArray(); | |||||
| var proto = Marshal.AllocHGlobal(bytes.Length); | |||||
| var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||||
| var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak | |||||
| Marshal.Copy(bytes, 0, proto, bytes.Length); | Marshal.Copy(bytes, 0, proto, bytes.Length); | ||||
| uint len = (uint)bytes.Length; | uint len = (uint)bytes.Length; | ||||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); | c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); | ||||
| @@ -488,6 +488,8 @@ namespace Tensorflow | |||||
| switch (value) | switch (value) | ||||
| { | { | ||||
| case String str: | |||||
| return constant_op.constant(str, dtype: TF_DataType.TF_STRING, name: name); | |||||
| case NDArray nd: | case NDArray nd: | ||||
| return constant_op.constant(nd, dtype: dtype, name: name); | return constant_op.constant(nd, dtype: dtype, name: name); | ||||
| case Tensor tensor: | case Tensor tensor: | ||||
| @@ -64,13 +64,12 @@ namespace Tensorflow | |||||
| public Session Session() | public Session Session() | ||||
| { | { | ||||
| defaultSession = new Session(); | |||||
| return defaultSession; | |||||
| return new Session(); | |||||
| } | } | ||||
| public Session Session(Graph graph) | |||||
| public Session Session(Graph graph, SessionOptions opts = null) | |||||
| { | { | ||||
| return new Session(graph); | |||||
| return new Session(graph, opts: opts); | |||||
| } | } | ||||
| public Session Session(SessionOptions opts) | public Session Session(SessionOptions opts) | ||||
| @@ -9,24 +9,18 @@ namespace TensorFlowBenchmark | |||||
| { | { | ||||
| static void Main(string[] args) | static void Main(string[] args) | ||||
| { | { | ||||
| #if DEBUG | |||||
| IConfig config = new DebugInProcessConfig(); | |||||
| #else | |||||
| IConfig config = null; | |||||
| #endif | |||||
| if (args?.Length > 0) | if (args?.Length > 0) | ||||
| { | { | ||||
| for (int i = 0; i < args.Length; i++) | for (int i = 0; i < args.Length; i++) | ||||
| { | { | ||||
| string name = $"TensorFlowBenchmark.{args[i]}"; | string name = $"TensorFlowBenchmark.{args[i]}"; | ||||
| var type = Type.GetType(name); | var type = Type.GetType(name); | ||||
| BenchmarkRunner.Run(type, config); | |||||
| BenchmarkRunner.Run(type); | |||||
| } | } | ||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| BenchmarkSwitcher.FromAssembly(Assembly.GetExecutingAssembly()).Run(args, config); | |||||
| BenchmarkSwitcher.FromAssembly(Assembly.GetExecutingAssembly()).Run(args, ManualConfig.Create(DefaultConfig.Instance).With(ConfigOptions.DisableOptimizationsValidator)); | |||||
| } | } | ||||
| Console.ReadLine(); | Console.ReadLine(); | ||||
| @@ -6,6 +6,7 @@ | |||||
| <NoWin32Manifest>true</NoWin32Manifest> | <NoWin32Manifest>true</NoWin32Manifest> | ||||
| <AssemblyName>TensorFlowBenchmark</AssemblyName> | <AssemblyName>TensorFlowBenchmark</AssemblyName> | ||||
| <RootNamespace>TensorFlowBenchmark</RootNamespace> | <RootNamespace>TensorFlowBenchmark</RootNamespace> | ||||
| <LangVersion>7.3</LangVersion> | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | ||||
| @@ -0,0 +1,76 @@ | |||||
| using System; | |||||
| using System.Runtime.CompilerServices; | |||||
| using System.Runtime.InteropServices; | |||||
| using BenchmarkDotNet.Attributes; | |||||
| using Google.Protobuf.WellKnownTypes; | |||||
| using NumSharp; | |||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowBenchmark.Unmanaged | |||||
| { | |||||
| public struct UnmanagedStruct | |||||
| { | |||||
| public int a; | |||||
| public long b; | |||||
| public UnmanagedStruct(int _) | |||||
| { | |||||
| a = 2; | |||||
| b = 3; | |||||
| } | |||||
| } | |||||
| [SimpleJob(launchCount: 1, warmupCount: 2, targetCount: 10)] | |||||
| [MinColumn, MaxColumn, MeanColumn, MedianColumn] | |||||
| public unsafe class StructCastBenchmark | |||||
| { | |||||
| private static void EnsureIsUnmanaged<T>(T _) where T : unmanaged | |||||
| { } | |||||
| static StructCastBenchmark() //if UnmanagedStruct is not unmanaged struct then this will fail to compile. | |||||
| => EnsureIsUnmanaged(new UnmanagedStruct()); | |||||
| private IntPtr data; | |||||
| private void* dataptr; | |||||
| [GlobalSetup] | |||||
| public void Setup() | |||||
| { | |||||
| data = Marshal.AllocHGlobal(Marshal.SizeOf<UnmanagedStruct>()); | |||||
| dataptr = data.ToPointer(); | |||||
| } | |||||
| [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)] | |||||
| public void Marshal_PtrToStructure() | |||||
| { | |||||
| UnmanagedStruct _; | |||||
| for (int i = 0; i < 10000; i++) | |||||
| { | |||||
| _ = Marshal.PtrToStructure<UnmanagedStruct>(data); | |||||
| } | |||||
| } | |||||
| [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)] | |||||
| public void PointerCast() | |||||
| { | |||||
| var dptr = dataptr; | |||||
| UnmanagedStruct _; | |||||
| for (int i = 0; i < 10000; i++) | |||||
| { | |||||
| _ = *(UnmanagedStruct*) dptr; | |||||
| } | |||||
| } | |||||
| [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)] | |||||
| public void Unsafe_Read() | |||||
| { | |||||
| var dptr = dataptr; | |||||
| UnmanagedStruct _; | |||||
| for (int i = 0; i < 10000; i++) | |||||
| { | |||||
| _ = Unsafe.Read<UnmanagedStruct>(dptr); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -80,26 +80,18 @@ namespace TensorFlowNET.Examples | |||||
| for (int epoch = 0; epoch < training_epochs; epoch++) | for (int epoch = 0; epoch < training_epochs; epoch++) | ||||
| { | { | ||||
| foreach (var (x, y) in zip<float>(train_X, train_Y)) | foreach (var (x, y) in zip<float>(train_X, train_Y)) | ||||
| { | |||||
| sess.run(optimizer, | |||||
| new FeedItem(X, x), | |||||
| new FeedItem(Y, y)); | |||||
| } | |||||
| sess.run(optimizer, (X, x), (Y, y)); | |||||
| // Display logs per epoch step | // Display logs per epoch step | ||||
| if ((epoch + 1) % display_step == 0) | if ((epoch + 1) % display_step == 0) | ||||
| { | { | ||||
| var c = sess.run(cost, | |||||
| new FeedItem(X, train_X), | |||||
| new FeedItem(Y, train_Y)); | |||||
| var c = sess.run(cost, (X, train_X), (Y, train_Y)); | |||||
| Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}"); | Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}"); | ||||
| } | } | ||||
| } | } | ||||
| Console.WriteLine("Optimization Finished!"); | Console.WriteLine("Optimization Finished!"); | ||||
| var training_cost = sess.run(cost, | |||||
| new FeedItem(X, train_X), | |||||
| new FeedItem(Y, train_Y)); | |||||
| var training_cost = sess.run(cost, (X, train_X), (Y, train_Y)); | |||||
| Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); | Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); | ||||
| // Testing example | // Testing example | ||||
| @@ -107,8 +99,7 @@ namespace TensorFlowNET.Examples | |||||
| var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); | var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); | ||||
| Console.WriteLine("Testing... (Mean square loss Comparison)"); | Console.WriteLine("Testing... (Mean square loss Comparison)"); | ||||
| var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), | var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), | ||||
| new FeedItem(X, test_X), | |||||
| new FeedItem(Y, test_Y)); | |||||
| (X, test_X), (Y, test_Y)); | |||||
| Console.WriteLine($"Testing cost={testing_cost}"); | Console.WriteLine($"Testing cost={testing_cost}"); | ||||
| var diff = Math.Abs((float)training_cost - (float)testing_cost); | var diff = Math.Abs((float)training_cost - (float)testing_cost); | ||||
| Console.WriteLine($"Absolute mean square loss difference: {diff}"); | Console.WriteLine($"Absolute mean square loss difference: {diff}"); | ||||
| @@ -102,7 +102,7 @@ namespace TensorFlowNET.Examples | |||||
| // Display logs per epoch step | // Display logs per epoch step | ||||
| if ((epoch + 1) % display_step == 0) | if ((epoch + 1) % display_step == 0) | ||||
| print($"Epoch: {(epoch + 1).ToString("D4")} Cost: {avg_cost.ToString("G9")} Elapse: {sw.ElapsedMilliseconds}ms"); | |||||
| print($"Epoch: {(epoch + 1):D4} Cost: {avg_cost:G9} Elapse: {sw.ElapsedMilliseconds}ms"); | |||||
| sw.Reset(); | sw.Reset(); | ||||
| } | } | ||||
| @@ -114,8 +114,8 @@ namespace TensorFlowNET.Examples | |||||
| var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); | var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); | ||||
| // Calculate accuracy | // Calculate accuracy | ||||
| var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); | var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); | ||||
| float acc = accuracy.eval((x, mnist.Test.Data), (y, mnist.Test.Labels)); | |||||
| print($"Accuracy: {acc.ToString("F4")}"); | |||||
| float acc = accuracy.eval(sess, (x, mnist.Test.Data), (y, mnist.Test.Labels)); | |||||
| print($"Accuracy: {acc:F4}"); | |||||
| return acc > 0.9; | return acc > 0.9; | ||||
| } | } | ||||
| @@ -84,7 +84,7 @@ namespace TensorFlowNET.Examples | |||||
| public void PrepareData() | public void PrepareData() | ||||
| { | { | ||||
| // get model file | // get model file | ||||
| string url = "http://download.tf.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz"; | |||||
| string url = "http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz"; | |||||
| Web.Download(url, modelDir, "ssd_mobilenet_v1_coco.tar.gz"); | Web.Download(url, modelDir, "ssd_mobilenet_v1_coco.tar.gz"); | ||||
| Compress.ExtractTGZ(Path.Join(modelDir, "ssd_mobilenet_v1_coco.tar.gz"), "./"); | Compress.ExtractTGZ(Path.Join(modelDir, "ssd_mobilenet_v1_coco.tar.gz"), "./"); | ||||
| @@ -21,6 +21,7 @@ using System.Collections.Generic; | |||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using System.IO; | using System.IO; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Threading.Tasks; | |||||
| using Tensorflow; | using Tensorflow; | ||||
| using TensorFlowNET.Examples.Utility; | using TensorFlowNET.Examples.Utility; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -381,10 +382,15 @@ namespace TensorFlowNET.Examples | |||||
| Tensor resized_input_tensor, Tensor bottleneck_tensor, string module_name) | Tensor resized_input_tensor, Tensor bottleneck_tensor, string module_name) | ||||
| { | { | ||||
| int how_many_bottlenecks = 0; | int how_many_bottlenecks = 0; | ||||
| foreach (var (label_name, label_lists) in image_lists) | |||||
| var kvs = image_lists.ToArray(); | |||||
| var categories = new string[] {"training", "testing", "validation"}; | |||||
| Parallel.For(0, kvs.Length, i => | |||||
| { | { | ||||
| foreach (var category in new string[] { "training", "testing", "validation" }) | |||||
| var (label_name, label_lists) = kvs[i]; | |||||
| Parallel.For(0, categories.Length, j => | |||||
| { | { | ||||
| var category = categories[j]; | |||||
| var category_list = label_lists[category]; | var category_list = label_lists[category]; | ||||
| foreach (var (index, unused_base_name) in enumerate(category_list)) | foreach (var (index, unused_base_name) in enumerate(category_list)) | ||||
| { | { | ||||
| @@ -395,8 +401,8 @@ namespace TensorFlowNET.Examples | |||||
| if (how_many_bottlenecks % 300 == 0) | if (how_many_bottlenecks % 300 == 0) | ||||
| print($"{how_many_bottlenecks} bottleneck files created."); | print($"{how_many_bottlenecks} bottleneck files created."); | ||||
| } | } | ||||
| } | |||||
| } | |||||
| }); | |||||
| }); | |||||
| } | } | ||||
| private float[] get_or_create_bottleneck(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists, | private float[] get_or_create_bottleneck(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists, | ||||
| @@ -508,7 +514,7 @@ namespace TensorFlowNET.Examples | |||||
| { | { | ||||
| // get a set of images to teach the network about the new classes | // get a set of images to teach the network about the new classes | ||||
| string fileName = "flower_photos.tgz"; | string fileName = "flower_photos.tgz"; | ||||
| string url = $"http://download.tf.org/example_images/{fileName}"; | |||||
| string url = $"http://download.tensorflow.org/example_images/{fileName}"; | |||||
| Web.Download(url, data_dir, fileName); | Web.Download(url, data_dir, fileName); | ||||
| Compress.ExtractTGZ(Path.Join(data_dir, fileName), data_dir); | Compress.ExtractTGZ(Path.Join(data_dir, fileName), data_dir); | ||||
| @@ -0,0 +1,54 @@ | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.IO; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||||
| { | |||||
| public class Dataset | |||||
| { | |||||
| string annot_path; | |||||
| int[] input_sizes; | |||||
| int batch_size; | |||||
| bool data_aug; | |||||
| int[] train_input_sizes; | |||||
| NDArray strides; | |||||
| NDArray anchors; | |||||
| Dictionary<int, string> classes; | |||||
| int num_classes; | |||||
| int anchor_per_scale; | |||||
| int max_bbox_per_scale; | |||||
| string[] annotations; | |||||
| int num_samples; | |||||
| int batch_count; | |||||
| public int Length = 0; | |||||
| public Dataset(string dataset_type, Config cfg) | |||||
| { | |||||
| annot_path = dataset_type == "train" ? cfg.TRAIN.ANNOT_PATH : cfg.TEST.ANNOT_PATH; | |||||
| input_sizes = dataset_type == "train" ? cfg.TRAIN.INPUT_SIZE : cfg.TEST.INPUT_SIZE; | |||||
| batch_size = dataset_type == "train" ? cfg.TRAIN.BATCH_SIZE : cfg.TEST.BATCH_SIZE; | |||||
| data_aug = dataset_type == "train" ? cfg.TRAIN.DATA_AUG : cfg.TEST.DATA_AUG; | |||||
| train_input_sizes = cfg.TRAIN.INPUT_SIZE; | |||||
| strides = np.array(cfg.YOLO.STRIDES); | |||||
| classes = Utils.read_class_names(cfg.YOLO.CLASSES); | |||||
| num_classes = classes.Count; | |||||
| anchors = np.array(Utils.get_anchors(cfg.YOLO.ANCHORS)); | |||||
| anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE; | |||||
| max_bbox_per_scale = 150; | |||||
| annotations = load_annotations(); | |||||
| num_samples = len(annotations); | |||||
| batch_count = 0; | |||||
| } | |||||
| string[] load_annotations() | |||||
| { | |||||
| return File.ReadAllLines(annot_path); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,139 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.IO; | |||||
| using System.Text; | |||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||||
| { | |||||
| /// <summary> | |||||
| /// Implementation of YOLO v3 object detector in Tensorflow | |||||
| /// https://github.com/YunYang1994/tensorflow-yolov3 | |||||
| /// </summary> | |||||
| public class Main : IExample | |||||
| { | |||||
| public bool Enabled { get; set; } = true; | |||||
| public bool IsImportingGraph { get; set; } = false; | |||||
| public string Name => "YOLOv3"; | |||||
| #region args | |||||
| Dictionary<int, string> classes; | |||||
| int num_classes; | |||||
| float learn_rate_init; | |||||
| float learn_rate_end; | |||||
| int first_stage_epochs; | |||||
| int second_stage_epochs; | |||||
| int warmup_periods; | |||||
| string time; | |||||
| float moving_ave_decay; | |||||
| int max_bbox_per_scale; | |||||
| int steps_per_period; | |||||
| Dataset trainset, testset; | |||||
| Config cfg; | |||||
| Tensor input_data; | |||||
| Tensor label_sbbox; | |||||
| Tensor label_mbbox; | |||||
| Tensor label_lbbox; | |||||
| Tensor true_sbboxes; | |||||
| Tensor true_mbboxes; | |||||
| Tensor true_lbboxes; | |||||
| Tensor trainable; | |||||
| Session sess; | |||||
| YOLOv3 model; | |||||
| #endregion | |||||
| public bool Run() | |||||
| { | |||||
| PrepareData(); | |||||
| var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); | |||||
| var options = new SessionOptions(); | |||||
| options.SetConfig(new ConfigProto { AllowSoftPlacement = true }); | |||||
| using (var sess = tf.Session(graph, opts: options)) | |||||
| { | |||||
| Train(sess); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| public void Train(Session sess) | |||||
| { | |||||
| } | |||||
| public void Test(Session sess) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public Graph BuildGraph() | |||||
| { | |||||
| var graph = new Graph().as_default(); | |||||
| tf_with(tf.name_scope("define_input"), scope => | |||||
| { | |||||
| input_data = tf.placeholder(dtype: tf.float32, name: "input_data"); | |||||
| label_sbbox = tf.placeholder(dtype: tf.float32, name: "label_sbbox"); | |||||
| label_mbbox = tf.placeholder(dtype: tf.float32, name: "label_mbbox"); | |||||
| label_lbbox = tf.placeholder(dtype: tf.float32, name: "label_lbbox"); | |||||
| true_sbboxes = tf.placeholder(dtype: tf.float32, name: "sbboxes"); | |||||
| true_mbboxes = tf.placeholder(dtype: tf.float32, name: "mbboxes"); | |||||
| true_lbboxes = tf.placeholder(dtype: tf.float32, name: "lbboxes"); | |||||
| trainable = tf.placeholder(dtype: tf.@bool, name: "training"); | |||||
| }); | |||||
| tf_with(tf.name_scope("define_loss"), scope => | |||||
| { | |||||
| model = new YOLOv3(cfg, input_data, trainable); | |||||
| }); | |||||
| tf_with(tf.name_scope("define_weight_decay"), scope => | |||||
| { | |||||
| var moving_ave = tf.train.ExponentialMovingAverage(moving_ave_decay).apply((RefVariable[])tf.trainable_variables()); | |||||
| }); | |||||
| return graph; | |||||
| } | |||||
| public Graph ImportGraph() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public void Predict(Session sess) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public void PrepareData() | |||||
| { | |||||
| cfg = new Config(Name); | |||||
| string dataDir = Path.Combine(Name, "data"); | |||||
| Directory.CreateDirectory(dataDir); | |||||
| classes = Utils.read_class_names(cfg.YOLO.CLASSES); | |||||
| num_classes = classes.Count; | |||||
| learn_rate_init = cfg.TRAIN.LEARN_RATE_INIT; | |||||
| learn_rate_end = cfg.TRAIN.LEARN_RATE_END; | |||||
| first_stage_epochs = cfg.TRAIN.FISRT_STAGE_EPOCHS; | |||||
| second_stage_epochs = cfg.TRAIN.SECOND_STAGE_EPOCHS; | |||||
| warmup_periods = cfg.TRAIN.WARMUP_EPOCHS; | |||||
| DateTime now = DateTime.Now; | |||||
| time = $"{now.Year}-{now.Month}-{now.Day}-{now.Hour}-{now.Minute}-{now.Minute}"; | |||||
| moving_ave_decay = cfg.YOLO.MOVING_AVE_DECAY; | |||||
| max_bbox_per_scale = 150; | |||||
| trainset = new Dataset("train", cfg); | |||||
| testset = new Dataset("test", cfg); | |||||
| steps_per_period = trainset.Length; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,27 @@ | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.IO; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||||
| { | |||||
| class Utils | |||||
| { | |||||
| public static Dictionary<int, string> read_class_names(string file) | |||||
| { | |||||
| var classes = new Dictionary<int, string>(); | |||||
| foreach (var line in File.ReadAllLines(file)) | |||||
| classes[classes.Count] = line; | |||||
| return classes; | |||||
| } | |||||
| public static NDArray get_anchors(string file) | |||||
| { | |||||
| return np.array(File.ReadAllText(file).Split(',') | |||||
| .Select(x => float.Parse(x)) | |||||
| .ToArray()).reshape(3, 3, 2); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,65 @@ | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||||
| { | |||||
| public class YOLOv3 | |||||
| { | |||||
| Config cfg; | |||||
| Tensor trainable; | |||||
| Tensor input_data; | |||||
| Dictionary<int, string> classes; | |||||
| int num_class; | |||||
| NDArray strides; | |||||
| NDArray anchors; | |||||
| int anchor_per_scale; | |||||
| float iou_loss_thresh; | |||||
| string upsample_method; | |||||
| Tensor conv_lbbox; | |||||
| Tensor conv_mbbox; | |||||
| Tensor conv_sbbox; | |||||
| public YOLOv3(Config cfg_, Tensor input_data_, Tensor trainable_) | |||||
| { | |||||
| cfg = cfg_; | |||||
| input_data = input_data_; | |||||
| trainable = trainable_; | |||||
| classes = Utils.read_class_names(cfg.YOLO.CLASSES); | |||||
| num_class = len(classes); | |||||
| strides = np.array(cfg.YOLO.STRIDES); | |||||
| anchors = Utils.get_anchors(cfg.YOLO.ANCHORS); | |||||
| anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE; | |||||
| iou_loss_thresh = cfg.YOLO.IOU_LOSS_THRESH; | |||||
| upsample_method = cfg.YOLO.UPSAMPLE_METHOD; | |||||
| (conv_lbbox, conv_mbbox, conv_sbbox) = __build_nework(input_data); | |||||
| tf_with(tf.variable_scope("pred_sbbox"), scope => | |||||
| { | |||||
| // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); | |||||
| }); | |||||
| tf_with(tf.variable_scope("pred_mbbox"), scope => | |||||
| { | |||||
| // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); | |||||
| }); | |||||
| tf_with(tf.variable_scope("pred_lbbox"), scope => | |||||
| { | |||||
| // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); | |||||
| }); | |||||
| } | |||||
| private (Tensor, Tensor, Tensor) __build_nework(Tensor input_data) | |||||
| { | |||||
| Tensor route_1, route_2; | |||||
| (route_1, route_2, input_data) = backbone.darknet53(input_data, trainable); | |||||
| return (conv_lbbox, conv_mbbox, conv_sbbox); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,28 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||||
| { | |||||
| class backbone | |||||
| { | |||||
| public static (Tensor, Tensor, Tensor) darknet53(Tensor input_data, Tensor trainable) | |||||
| { | |||||
| return tf_with(tf.variable_scope("darknet"), scope => | |||||
| { | |||||
| input_data = common.convolutional(input_data, filters_shape: new int[] { 3, 3, 3, 32 }, trainable: trainable, name: "conv0"); | |||||
| input_data = common.convolutional(input_data, filters_shape: new int[] { 3, 3, 32, 64 }, trainable: trainable, name: "conv1", downsample: true); | |||||
| foreach (var i in range(1)) | |||||
| input_data = common.residual_block(input_data, 64, 32, 64, trainable: trainable, name: $"residual{i + 0}"); | |||||
| var route_1 = input_data; | |||||
| var route_2 = input_data; | |||||
| return (route_1, route_2, input_data); | |||||
| }); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,72 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||||
| { | |||||
| class common | |||||
| { | |||||
| public static Tensor convolutional(Tensor input_data, int[] filters_shape, Tensor trainable, | |||||
| string name, bool downsample = false, bool activate = true, | |||||
| bool bn = true) | |||||
| { | |||||
| return tf_with(tf.variable_scope(name), scope => | |||||
| { | |||||
| int[] strides; | |||||
| string padding; | |||||
| if (downsample) | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| else | |||||
| { | |||||
| strides = new int[] { 1, 1, 1, 1 }; | |||||
| padding = "SAME"; | |||||
| } | |||||
| var weight = tf.get_variable(name: "weight", dtype: tf.float32, trainable: true, | |||||
| shape: filters_shape, initializer: tf.random_normal_initializer(stddev: 0.01f)); | |||||
| var conv = tf.nn.conv2d(input: input_data, filter: weight, strides: strides, padding: padding); | |||||
| if (bn) | |||||
| { | |||||
| conv = tf.layers.batch_normalization(conv, beta_initializer: tf.zeros_initializer, | |||||
| gamma_initializer: tf.ones_initializer, | |||||
| moving_mean_initializer: tf.zeros_initializer, | |||||
| moving_variance_initializer: tf.ones_initializer, training: trainable); | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| if (activate) | |||||
| conv = tf.nn.leaky_relu(conv, alpha: 0.1f); | |||||
| return conv; | |||||
| }); | |||||
| } | |||||
| public static Tensor residual_block(Tensor input_data, int input_channel, int filter_num1, | |||||
| int filter_num2, Tensor trainable, string name) | |||||
| { | |||||
| var short_cut = input_data; | |||||
| return tf_with(tf.variable_scope(name), scope => | |||||
| { | |||||
| input_data = convolutional(input_data, filters_shape: new int[] { 1, 1, input_channel, filter_num1 }, | |||||
| trainable: trainable, name: "conv1"); | |||||
| input_data = convolutional(input_data, filters_shape: new int[] { 3, 3, filter_num1, filter_num2 }, | |||||
| trainable: trainable, name: "conv2"); | |||||
| var residual_output = input_data + short_cut; | |||||
| return residual_output; | |||||
| }); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,94 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.IO; | |||||
| using System.Text; | |||||
| namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||||
| { | |||||
| public class Config | |||||
| { | |||||
| public YoloConfig YOLO; | |||||
| public TrainConfig TRAIN; | |||||
| public TestConfig TEST; | |||||
| public Config(string root) | |||||
| { | |||||
| YOLO = new YoloConfig(root); | |||||
| TRAIN = new TrainConfig(root); | |||||
| TEST = new TestConfig(root); | |||||
| } | |||||
| public class YoloConfig | |||||
| { | |||||
| string _root; | |||||
| public string CLASSES; | |||||
| public string ANCHORS; | |||||
| public float MOVING_AVE_DECAY = 0.9995f; | |||||
| public int[] STRIDES = new int[] { 8, 16, 32 }; | |||||
| public int ANCHOR_PER_SCALE = 3; | |||||
| public float IOU_LOSS_THRESH = 0.5f; | |||||
| public string UPSAMPLE_METHOD = "resize"; | |||||
| public string ORIGINAL_WEIGHT; | |||||
| public string DEMO_WEIGHT; | |||||
| public YoloConfig(string root) | |||||
| { | |||||
| _root = root; | |||||
| CLASSES = Path.Combine(_root, "data", "classes", "coco.names"); | |||||
| ANCHORS = Path.Combine(_root, "data", "anchors", "basline_anchors.txt"); | |||||
| ORIGINAL_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco.ckpt"); | |||||
| DEMO_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco_demo.ckpt"); | |||||
| } | |||||
| } | |||||
| public class TrainConfig | |||||
| { | |||||
| string _root; | |||||
| public int BATCH_SIZE = 6; | |||||
| public int[] INPUT_SIZE = new int[] { 320, 352, 384, 416, 448, 480, 512, 544, 576, 608 }; | |||||
| public bool DATA_AUG = true; | |||||
| public float LEARN_RATE_INIT = 1e-4f; | |||||
| public float LEARN_RATE_END = 1e-6f; | |||||
| public int WARMUP_EPOCHS = 2; | |||||
| public int FISRT_STAGE_EPOCHS = 20; | |||||
| public int SECOND_STAGE_EPOCHS = 30; | |||||
| public string INITIAL_WEIGHT; | |||||
| public string ANNOT_PATH; | |||||
| public TrainConfig(string root) | |||||
| { | |||||
| _root = root; | |||||
| INITIAL_WEIGHT = Path.Combine(_root, "data", "checkpoint", "yolov3_coco_demo.ckpt"); | |||||
| ANNOT_PATH = Path.Combine(_root, "data", "dataset", "voc_train.txt"); | |||||
| } | |||||
| } | |||||
| public class TestConfig | |||||
| { | |||||
| string _root; | |||||
| public int BATCH_SIZE = 2; | |||||
| public int[] INPUT_SIZE = new int[] { 544 }; | |||||
| public bool DATA_AUG = false; | |||||
| public bool WRITE_IMAGE = true; | |||||
| public string WRITE_IMAGE_PATH; | |||||
| public string WEIGHT_FILE; | |||||
| public bool WRITE_IMAGE_SHOW_LABEL = true; | |||||
| public bool SHOW_LABEL = true; | |||||
| public int SECOND_STAGE_EPOCHS = 30; | |||||
| public float SCORE_THRESHOLD = 0.3f; | |||||
| public float IOU_THRESHOLD = 0.45f; | |||||
| public string ANNOT_PATH; | |||||
| public TestConfig(string root) | |||||
| { | |||||
| _root = root; | |||||
| ANNOT_PATH = Path.Combine(_root, "data", "dataset", "voc_test.txt"); | |||||
| WRITE_IMAGE_PATH = Path.Combine(_root, "data", "detection"); | |||||
| WEIGHT_FILE = Path.Combine(_root, "checkpoint", "yolov3_test_loss=9.2099.ckpt-5"); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -6,6 +6,14 @@ | |||||
| <GeneratePackageOnBuild>false</GeneratePackageOnBuild> | <GeneratePackageOnBuild>false</GeneratePackageOnBuild> | ||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||||
| <OutputPath>bin\debug-gpu</OutputPath> | |||||
| </PropertyGroup> | |||||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | |||||
| <OutputPath>bin\release-gpu</OutputPath> | |||||
| </PropertyGroup> | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="Colorful.Console" Version="1.2.9" /> | <PackageReference Include="Colorful.Console" Version="1.2.9" /> | ||||
| <PackageReference Include="Newtonsoft.Json" Version="12.0.2" /> | <PackageReference Include="Newtonsoft.Json" Version="12.0.2" /> | ||||