@@ -431,17 +431,19 @@ namespace Tensorflow
         /// <param name="input"></param>
         /// <param name="axis"></param>
         /// <returns></returns>
-        public Tensor reduce_sum(Tensor input, int? axis = null, int? reduction_indices = null)
+        public Tensor reduce_sum(Tensor input, int? axis = null, int? reduction_indices = null,
+            bool keepdims = false, string name = null)
         {
             if (!axis.HasValue && reduction_indices.HasValue)
                 return math_ops.reduce_sum(input, reduction_indices.Value);
             else if (axis.HasValue && !reduction_indices.HasValue)
                 return math_ops.reduce_sum(input, axis.Value);
-            return math_ops.reduce_sum(input);
+            return math_ops.reduce_sum(input, keepdims: keepdims, name: name);
         }

-        public Tensor reduce_sum(Tensor input, int[] axis, int? reduction_indices = null)
-            => math_ops.reduce_sum(input, axis);
+        public Tensor reduce_sum(Tensor input, int[] axis, int? reduction_indices = null,
+            bool keepdims = false, string name = null)
+            => math_ops.reduce_sum(input, axis, keepdims: keepdims, name: name);

         /// <summary>
         /// Computes the maximum of elements across dimensions of a tensor.
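A minimal usage sketch of the extended overload (hypothetical values; `tf` is the binding instance used throughout this repo). Note that, as written above, the two scalar-axis branches still call math_ops.reduce_sum without forwarding keepdims or name; only the no-axis path and the int[]-axis overload honor the new parameters.

    // Hypothetical usage of the extended overload.
    var x = tf.constant(new float[,] { { 1f, 2f }, { 3f, 4f } });
    // No axis given: reduces over all dimensions and keeps them,
    // so the result has shape (1, 1) instead of being a scalar.
    var total = tf.reduce_sum(x, keepdims: true, name: "total");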
@@ -0,0 +1,17 @@
+using System;
+
+namespace Tensorflow
+{
+    public class LookupError : TensorflowException
+    {
+        public LookupError() : base()
+        {
+        }
+
+        public LookupError(string message) : base(message)
+        {
+        }
+    }
+}
@@ -0,0 +1,33 @@
+/*****************************************************************************
+   Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+******************************************************************************/
+
+using System;
+
+namespace Tensorflow.Gradients
+{
+    /// <summary>
+    /// Marks a method as registering an op type that has no gradient,
+    /// the equivalent of TensorFlow's REGISTER_NO_GRADIENT_OP("...") macro.
+    /// </summary>
+    public class RegisterNoGradient : Attribute
+    {
+        public string Name { get; set; }
+
+        public RegisterNoGradient(string name)
+        {
+            Name = name;
+        }
+    }
+}
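Usage mirrors the existing RegisterGradient attribute, as the math_grad changes later in this diff show. A hypothetical registration for illustration:

    // Hypothetical: "Size" stands in for any op without a gradient. The
    // registry scan (see the ops changes below) maps the op type to a
    // null delegate, and backprop then propagates nulls past the op.
    [RegisterNoGradient("Size")]
    public static Tensor[] _SizeGrad(Operation op, Tensor[] grads) => null;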
@@ -117,19 +117,44 @@ namespace Tensorflow
                 Tensor[] in_grads = null;
                 var is_partitioned_call = _IsPartitionedCall(op);
                 var is_func_call = false;
-                var has_out_grads = true;
+                var has_out_grads = out_grads.Exists(x => x != null);
                 if (has_out_grads && !stop_ops.Contains(op))
                 {
                     // A grad_fn must be defined, either as a function or as None
                     // for ops that do not have gradients.
-                    var grad_fn = ops.get_gradient_function(op);
-                    if (is_func_call)
+                    Func<Operation, Tensor[], Tensor[]> grad_fn = null;
+                    try
                     {
+                        grad_fn = ops.get_gradient_function(op);
+                    }
+                    catch (LookupError)
+                    {
+                        if (is_func_call)
+                        {
+                            if (is_partitioned_call)
+                            {
+                                // Not implemented yet: resolve the gradient of a
+                                // PartitionedCall from the called function's definition.
+                            }
+                            else
+                            {
+                                // Not implemented yet: resolve the gradient of a plain
+                                // function call from the called function's definition.
+                            }
+                        }
+                        else
+                        {
+                            throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})");
+                        }
                     }
-                    else
+
+                    // if (loop_state)
+                    //     loop_state.EnterGradWhileContext(op, before: false);
+
+                    if ((is_func_call || grad_fn != null) && has_out_grads)
                     {
+                        // NOTE: If _AggregatedGrads didn't compute a value for the i'th
+                        // output, it means that the cost does not depend on output[i],
+                        // therefore dC/doutput[i] is 0.
                         foreach (var (i, out_grad) in enumerate(out_grads))
                         {
                             if (out_grad == null)
@@ -143,13 +168,11 @@ namespace Tensorflow
                         tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
                         {
-                            string name1 = scope1;
                             if (grad_fn != null)
                             {
                                 in_grads = _MaybeCompile(grad_scope, op, out_grads.Select(x => x[0]).ToArray(), null, grad_fn);
+                                _VerifyGeneratedGradients(in_grads, op);
                             }
-                            _VerifyGeneratedGradients(in_grads, op);

                             if (gate_gradients && in_grads.Count(x => x != null) > 1)
                             {
                                 ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
@@ -157,6 +180,12 @@ namespace Tensorflow
                             }
                         });
                     }
+                    else
+                    {
+                        // If no grad_fn is defined or none of out_grads is available,
+                        // just propagate a list of None backwards.
+                        in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
+                    }
                 }
                 else
                 {
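Taken together, gradient lookup in the backprop helper now has three outcomes rather than two. A condensed summary sketch, not a literal excerpt from the code above:

    // grad_fn lookup outcomes after this change:
    //   1. non-null delegate -> run it under the "<op.name>_grad" scope,
    //      then _VerifyGeneratedGradients;
    //   2. null delegate     -> op was registered via RegisterNoGradient,
    //      so emit one null per non-eager input;
    //   3. LookupError       -> fatal for ordinary ops; function calls fall
    //      into a placeholder branch for now.
    in_grads = grad_fn != null
        ? _MaybeCompile(grad_scope, op, out_grads.Select(x => x[0]).ToArray(), null, grad_fn)
        : new Tensor[_NonEagerInputs(op, xs).Count()];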
@@ -96,12 +96,11 @@ namespace Tensorflow.Gradients
             });
         }

-        [RegisterGradient("GreaterEqual")]
-        public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads)
-        {
-            var grad = grads[0];
-            throw new NotImplementedException("_GreaterEqualGrad");
-        }
+        [RegisterNoGradient("GreaterEqual")]
+        public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads) => null;
+
+        [RegisterNoGradient("ZerosLike")]
+        public static Tensor[] _ZerosLike(Operation op, Tensor[] grads) => null;

         [RegisterGradient("Identity")]
         public static Tensor[] _IdGrad(Operation op, Tensor[] grads)
@@ -415,7 +414,9 @@ namespace Tensorflow.Gradients
                 var rank = input_0_shape.Length;
                 if (Enumerable.SequenceEqual(Enumerable.Range(0, rank), axes.Data<int>()))
                 {
-                    grad = array_ops.reshape(grad, new int[] { 1 });
+                    var new_shape = range(rank).Select(x => 1).ToArray();
+                    grad = array_ops.reshape(grad, new_shape);
+                    // If shape is not fully defined (but rank is), we use Shape.
                     if (!input_0_shape.Contains(-1))
                         input_shape = constant_op.constant(input_0_shape);
                     else
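The reshape target matters for inputs of rank > 1: when every axis is reduced, grad arrives as a scalar and is subsequently tiled back up to the input shape. Reshaping to one 1 per input dimension keeps the rank compatible with that tile, whereas the old new int[] { 1 } only worked for rank-1 inputs. An illustrative sketch with hypothetical shapes:

    // input shape (2, 3), reduced over axes {0, 1} -> grad is effectively scalar.
    // new_shape becomes {1, 1}: one 1 per input dimension.
    var reshaped = array_ops.reshape(grad, new int[] { 1, 1 });
    // The sum gradient then tiles `reshaped` by the input shape to recover
    // (2, 3); a shape-(1,) tensor cannot be tiled to rank 2, hence the fix.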
@@ -39,6 +39,14 @@ namespace Tensorflow
             gradientFunctions[name] = func;
         }

+        public static void RegisterNoGradientFunction(string name)
+        {
+            if (gradientFunctions == null)
+                gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>();
+
+            // A null delegate marks the op type as "registered, but has no gradient".
+            gradientFunctions[name] = null;
+        }
+
         public static Func<Operation, Tensor[], Tensor[]> get_gradient_function(Operation op)
         {
             if (op.inputs == null) return null;
@@ -68,11 +76,18 @@ namespace Tensorflow
                             args: new object[] { oper, out_grads }) as Tensor[]
                         );
                     }
+
+                    // REGISTER_NO_GRADIENT_OP
+                    methods = g.GetMethods().Where(x => x.GetCustomAttribute<RegisterNoGradient>() != null)
+                        .ToArray();
+
+                    foreach (var m in methods)
+                        RegisterNoGradientFunction(m.GetCustomAttribute<RegisterNoGradient>().Name);
                 }
             }

             if (!gradientFunctions.ContainsKey(op.type))
-                throw new NotImplementedException($"can't get graident function through get_gradient_function {op.type}");
+                throw new LookupError($"can't get gradient function through get_gradient_function {op.type}");

             return gradientFunctions[op.type];
         }
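A hedged sketch of the registry contract this hunk establishes (names taken from the diff; `op` stands for any Operation):

    // A no-gradient op is stored as a null delegate rather than left absent:
    ops.RegisterNoGradientFunction("GreaterEqual");

    // Lookup now distinguishes "registered as no-gradient" (returns null)
    // from "never registered" (throws LookupError):
    try
    {
        var fn = ops.get_gradient_function(op);   // null for GreaterEqual
    }
    catch (LookupError)
    {
        // op.type has no registration of either kind
    }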
@@ -154,7 +154,7 @@ namespace Tensorflow
         public static Tensor sigmoid_cross_entropy_with_logits(Tensor labels, Tensor logits, string name = null)
         {
-            return tf_with(ops.name_scope(name, "", new { }), scope =>
+            return tf_with(ops.name_scope(name, "logistic_loss", new { logits, labels }), scope =>
             {
                 name = scope;
                 logits = ops.convert_to_tensor(logits, name: "logits");
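The "logistic_loss" scope name and the (logits, labels) value tuple match upstream TensorFlow's Python implementation. For reference, the body of this method computes the numerically stable form of the loss, with x = logits and z = labels:

    max(x, 0) - x * z + log(1 + exp(-abs(x)))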
@@ -19,7 +19,7 @@
     Docs: https://tensorflownet.readthedocs.io</Description>
     <AssemblyVersion>0.11.4.0</AssemblyVersion>
     <PackageReleaseNotes>Changes since v0.10.0:
-1. Upgrade NumSharp to v0.20.
+1. Upgrade NumSharp to v0.20.3.
 2. Add DisposableObject class to manage object lifetime.
 3. Add tf.no_op, tf.nn.in_top_k, tf.GraphKeys and tf.trainable_variables.
 4. Change tensorflow to non-static class in order to execute some initialization process.
@@ -28,7 +28,8 @@ Docs: https://tensorflownet.readthedocs.io</Description>
 7. Add tf.image related APIs.
 8. Add tf.random_normal, tf.constant, tf.pad, tf.shape, tf.image.resize_nearest_neighbor.
 9. MultiThread is safe.
-10. Support n-dim indexing for tensor.</PackageReleaseNotes>
+10. Support n-dim indexing for tensor.
+11. Add RegisterNoGradient.</PackageReleaseNotes>
     <LangVersion>7.3</LangVersion>
     <FileVersion>0.11.4.0</FileVersion>
     <PackageLicenseFile>LICENSE</PackageLicenseFile>
@@ -62,7 +63,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>
   <ItemGroup>
     <PackageReference Include="Google.Protobuf" Version="3.5.1" />
-    <PackageReference Include="NumSharp" Version="0.20.2" />
+    <PackageReference Include="NumSharp" Version="0.20.3" />
   </ItemGroup>

   <ItemGroup>
@@ -2,7 +2,7 @@
   <PropertyGroup>
     <OutputType>Exe</OutputType>
-    <TargetFramework>netcoreapp2.2</TargetFramework>
+    <TargetFramework>netcoreapp3.0</TargetFramework>
     <NoWin32Manifest>true</NoWin32Manifest>
     <AssemblyName>TensorFlowBenchmark</AssemblyName>
     <RootNamespace>TensorFlowBenchmark</RootNamespace>
@@ -0,0 +1,154 @@
+/*****************************************************************************
+   Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+******************************************************************************/
+
+using NumSharp;
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Text;
+using Tensorflow;
+using static Tensorflow.Binding;
+
+namespace TensorFlowNET.Examples
+{
+    /// <summary>
+    /// A fully connected network, ported from:
+    /// "How to optimise your input pipeline with queues and multi-threading"
+    /// https://blog.metaflow.fr/tensorflow-how-to-optimise-your-input-pipeline-with-queues-and-multi-threading-e7c3874157e0
+    /// </summary>
+    public class FullyConnected : IExample
+    {
+        public bool Enabled { get; set; } = true;
+        public bool IsImportingGraph { get; set; }
+
+        public string Name => "Fully Connected Neural Network";
+
+        Tensor input = null;
+        Tensor x_inputs_data = null;
+        Tensor y_inputs_data = null;
+        Tensor accuracy = null;
+        Tensor y_true = null;
+        Tensor loss_op = null;
+        Operation train_op = null;
+
+        public Graph BuildGraph()
+        {
+            var g = tf.get_default_graph();
+
+            // batches of 128 samples, each containing 1024 data points
+            x_inputs_data = tf.random_normal(new[] { 128, 1024 }, mean: 0, stddev: 1);
+            // We will try to predict this law:
+            // predict 1 if the sum of the elements is positive and 0 otherwise
+            y_inputs_data = tf.cast(tf.reduce_sum(x_inputs_data, axis: 1, keepdims: true) > 0, tf.int32);
+
+            Tensor z = null;
+
+            tf_with(tf.variable_scope("placeholder"), delegate
+            {
+                input = tf.placeholder(tf.float32, shape: (-1, 1024));
+                y_true = tf.placeholder(tf.int32, shape: (-1, 1));
+            });
+
+            tf_with(tf.variable_scope("FullyConnected"), delegate
+            {
+                var w = tf.get_variable("w", shape: (1024, 1024), initializer: tf.random_normal_initializer(stddev: 0.1f));
+                var b = tf.get_variable("b", shape: 1024, initializer: tf.constant_initializer(0.1));
+                z = tf.matmul(input, w) + b;
+                var y = tf.nn.relu(z);
+
+                var w2 = tf.get_variable("w2", shape: (1024, 1), initializer: tf.random_normal_initializer(stddev: 0.1f));
+                var b2 = tf.get_variable("b2", shape: 1, initializer: tf.constant_initializer(0.1));
+                z = tf.matmul(y, w2) + b2;
+            });
+
+            tf_with(tf.variable_scope("Loss"), delegate
+            {
+                var losses = tf.nn.sigmoid_cross_entropy_with_logits(tf.cast(y_true, tf.float32), z);
+                loss_op = tf.reduce_mean(losses);
+            });
+
+            tf_with(tf.variable_scope("Accuracy"), delegate
+            {
+                var y_pred = tf.cast(z > 0, tf.int32);
+                accuracy = tf.reduce_mean(tf.cast(tf.equal(y_pred, y_true), tf.float32));
+                // accuracy = tf.Print(accuracy, data =[accuracy], message = "accuracy:")
+            });
+
+            // We add the training operation, ...
+            var adam = tf.train.AdamOptimizer(0.01f);
+            train_op = adam.minimize(loss_op, name: "train_op");
+
+            return g;
+        }
+
+        public Graph ImportGraph()
+        {
+            throw new NotImplementedException();
+        }
+
+        public void Predict(Session sess)
+        {
+            throw new NotImplementedException();
+        }
+
+        public void PrepareData()
+        {
+            throw new NotImplementedException();
+        }
+
+        public bool Run()
+        {
+            var g = BuildGraph();
+            using (var sess = tf.Session())
+                Train(sess);
+
+            return true;
+        }
+
+        public void Test(Session sess)
+        {
+            throw new NotImplementedException();
+        }
+
+        public void Train(Session sess)
+        {
+            var sw = new Stopwatch();
+            sw.Start();
+
+            // init variables
+            sess.run(tf.global_variables_initializer());
+
+            // check the accuracy before training
+            var (x_input, y_input) = sess.run((x_inputs_data, y_inputs_data));
+            var acc = sess.run(accuracy, (input, x_input), (y_true, y_input));
+            print($"Pre-train accuracy: {acc}");
+
+            // training
+            foreach (var i in range(5000))
+            {
+                // by sampling some input data (fetching)
+                (x_input, y_input) = sess.run((x_inputs_data, y_inputs_data));
+                var (_, loss) = sess.run((train_op, loss_op), (input, x_input), (y_true, y_input));
+
+                // We regularly check the loss
+                if (i % 500 == 0)
+                    print($"iter:{i} - loss:{loss}");
+            }
+
+            // Finally, we check our final accuracy
+            (x_input, y_input) = sess.run((x_inputs_data, y_inputs_data));
+            acc = sess.run(accuracy, (input, x_input), (y_true, y_input));
+            print($"Final accuracy: {acc}");
+
+            print($"Time taken: {sw.Elapsed.TotalSeconds}s");
+        }
+    }
+}
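For completeness, a hypothetical driver showing how an IExample implementation like this is invoked; the repo's actual example runner enumerates and dispatches IExample types itself:

    // Hypothetical usage sketch.
    var example = new FullyConnected();
    if (example.Enabled)
        example.Run();   // builds the graph, trains 5000 steps, prints loss/accuracy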
@@ -2,7 +2,7 @@
   <PropertyGroup>
     <OutputType>Exe</OutputType>
-    <TargetFramework>netcoreapp2.2</TargetFramework>
+    <TargetFramework>netcoreapp3.0</TargetFramework>
     <GeneratePackageOnBuild>false</GeneratePackageOnBuild>
   </PropertyGroup>
@@ -2,7 +2,7 @@
   <PropertyGroup>
     <OutputType>Exe</OutputType>
-    <TargetFramework>netcoreapp2.2</TargetFramework>
+    <TargetFramework>netcoreapp3.0</TargetFramework>
     <GeneratePackageOnBuild>false</GeneratePackageOnBuild>
   </PropertyGroup>
@@ -1,7 +1,7 @@
 <Project Sdk="Microsoft.NET.Sdk">

   <PropertyGroup>
-    <TargetFramework>netcoreapp2.2</TargetFramework>
+    <TargetFramework>netcoreapp3.0</TargetFramework>
     <IsPackable>false</IsPackable>