| @@ -431,17 +431,19 @@ namespace Tensorflow | |||
| /// <param name="input"></param> | |||
| /// <param name="axis"></param> | |||
| /// <returns></returns> | |||
| public Tensor reduce_sum(Tensor input, int? axis = null, int? reduction_indices = null) | |||
| public Tensor reduce_sum(Tensor input, int? axis = null, int? reduction_indices = null, | |||
| bool keepdims = false, string name = null) | |||
| { | |||
| if(!axis.HasValue && reduction_indices.HasValue) | |||
| return math_ops.reduce_sum(input, reduction_indices.Value); | |||
| else if (axis.HasValue && !reduction_indices.HasValue) | |||
| return math_ops.reduce_sum(input, axis.Value); | |||
| return math_ops.reduce_sum(input); | |||
| return math_ops.reduce_sum(input, keepdims: keepdims, name: name); | |||
| } | |||
| public Tensor reduce_sum(Tensor input, int[] axis, int? reduction_indices = null) | |||
| => math_ops.reduce_sum(input, axis); | |||
| public Tensor reduce_sum(Tensor input, int[] axis, int? reduction_indices = null, | |||
| bool keepdims = false, string name = null) | |||
| => math_ops.reduce_sum(input, axis, keepdims: keepdims, name: name); | |||
| /// <summary> | |||
| /// Computes the maximum of elements across dimensions of a tensor. | |||
| @@ -0,0 +1,17 @@ | |||
| using System; | |||
| namespace Tensorflow | |||
| { | |||
| public class LookupError : TensorflowException | |||
| { | |||
| public LookupError() : base() | |||
| { | |||
| } | |||
| public LookupError(string message) : base(message) | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,33 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| namespace Tensorflow.Gradients | |||
| { | |||
| /// <summary> | |||
| /// REGISTER_NO_GRADIENT_OP(""); | |||
| /// </summary> | |||
| public class RegisterNoGradient : Attribute | |||
| { | |||
| public string Name { get; set; } | |||
| public RegisterNoGradient(string name) | |||
| { | |||
| Name = name; | |||
| } | |||
| } | |||
| } | |||
| @@ -117,19 +117,44 @@ namespace Tensorflow | |||
| Tensor[] in_grads = null; | |||
| var is_partitioned_call = _IsPartitionedCall(op); | |||
| var is_func_call = false; | |||
| var has_out_grads = true; | |||
| var has_out_grads = out_grads.Exists(x => x != null); | |||
| if (has_out_grads && !stop_ops.Contains(op)) | |||
| { | |||
| // A grad_fn must be defined, either as a function or as None | |||
| // for ops that do not have gradients. | |||
| var grad_fn = ops.get_gradient_function(op); | |||
| if (is_func_call) | |||
| Func<Operation, Tensor[], Tensor[]> grad_fn = null; | |||
| try | |||
| { | |||
| grad_fn = ops.get_gradient_function(op); | |||
| } | |||
| catch (LookupError) | |||
| { | |||
| if (is_func_call) | |||
| { | |||
| if (is_partitioned_call) | |||
| { | |||
| } | |||
| else | |||
| { | |||
| } | |||
| } | |||
| else | |||
| { | |||
| throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})"); | |||
| } | |||
| } | |||
| else | |||
| // if (loop_state) | |||
| //loop_state.EnterGradWhileContext(op, before: false); | |||
| if ((is_func_call || grad_fn != null) && has_out_grads) | |||
| { | |||
| // NOTE: If _AggregatedGrads didn't compute a value for the i'th | |||
| // output, it means that the cost does not depend on output[i], | |||
| // therefore dC/doutput[i] is 0. | |||
| foreach (var (i, out_grad) in enumerate(out_grads)) | |||
| { | |||
| if (out_grad == null) | |||
| @@ -143,13 +168,11 @@ namespace Tensorflow | |||
| tf_with(ops.name_scope(op.name + "_grad"), scope1 => | |||
| { | |||
| string name1 = scope1; | |||
| if (grad_fn != null) | |||
| { | |||
| in_grads = _MaybeCompile(grad_scope, op, out_grads.Select(x => x[0]).ToArray(), null, grad_fn); | |||
| _VerifyGeneratedGradients(in_grads, op); | |||
| } | |||
| _VerifyGeneratedGradients(in_grads, op); | |||
| if (gate_gradients && in_grads.Count(x => x != null) > 1) | |||
| { | |||
| ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true); | |||
| @@ -157,6 +180,12 @@ namespace Tensorflow | |||
| } | |||
| }); | |||
| } | |||
| else | |||
| { | |||
| // If no grad_fn is defined or none of out_grads is available, | |||
| // just propagate a list of None backwards. | |||
| in_grads = new Tensor[_NonEagerInputs(op, xs).Count()]; | |||
| } | |||
| } | |||
| else | |||
| { | |||
| @@ -96,12 +96,11 @@ namespace Tensorflow.Gradients | |||
| }); | |||
| } | |||
| [RegisterGradient("GreaterEqual")] | |||
| public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads) | |||
| { | |||
| var grad = grads[0]; | |||
| throw new NotImplementedException("_GreaterEqualGrad"); | |||
| } | |||
| [RegisterNoGradient("GreaterEqual")] | |||
| public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads) => null; | |||
| [RegisterNoGradient("ZerosLike")] | |||
| public static Tensor[] _ZerosLike(Operation op, Tensor[] grads) => null; | |||
| [RegisterGradient("Identity")] | |||
| public static Tensor[] _IdGrad(Operation op, Tensor[] grads) | |||
| @@ -415,7 +414,9 @@ namespace Tensorflow.Gradients | |||
| var rank = input_0_shape.Length; | |||
| if (Enumerable.SequenceEqual(Enumerable.Range(0, rank), axes.Data<int>())) | |||
| { | |||
| grad = array_ops.reshape(grad, new int[] { 1 }); | |||
| var new_shape = range(rank).Select(x => 1).ToArray(); | |||
| grad = array_ops.reshape(grad, new_shape); | |||
| // If shape is not fully defined (but rank is), we use Shape. | |||
| if (!input_0_shape.Contains(-1)) | |||
| input_shape = constant_op.constant(input_0_shape); | |||
| else | |||
| @@ -39,6 +39,14 @@ namespace Tensorflow | |||
| gradientFunctions[name] = func; | |||
| } | |||
| public static void RegisterNoGradientFunction(string name) | |||
| { | |||
| if (gradientFunctions == null) | |||
| gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>(); | |||
| gradientFunctions[name] = null; | |||
| } | |||
| public static Func<Operation, Tensor[], Tensor[]> get_gradient_function(Operation op) | |||
| { | |||
| if (op.inputs == null) return null; | |||
| @@ -68,11 +76,18 @@ namespace Tensorflow | |||
| args: new object[] { oper, out_grads }) as Tensor[] | |||
| ); | |||
| } | |||
| // REGISTER_NO_GRADIENT_OP | |||
| methods = g.GetMethods().Where(x => x.GetCustomAttribute<RegisterNoGradient>() != null) | |||
| .ToArray(); | |||
| foreach (var m in methods) | |||
| RegisterNoGradientFunction(m.GetCustomAttribute<RegisterNoGradient>().Name); | |||
| } | |||
| } | |||
| if (!gradientFunctions.ContainsKey(op.type)) | |||
| throw new NotImplementedException($"can't get graident function through get_gradient_function {op.type}"); | |||
| throw new LookupError($"can't get graident function through get_gradient_function {op.type}"); | |||
| return gradientFunctions[op.type]; | |||
| } | |||
| @@ -154,7 +154,7 @@ namespace Tensorflow | |||
| public static Tensor sigmoid_cross_entropy_with_logits(Tensor labels, Tensor logits, string name = null) | |||
| { | |||
| return tf_with(ops.name_scope(name, "", new { }), scope => | |||
| return tf_with(ops.name_scope(name, "logistic_loss", new { logits, labels }), scope => | |||
| { | |||
| name = scope; | |||
| logits = ops.convert_to_tensor(logits, name: "logits"); | |||
| @@ -19,7 +19,7 @@ | |||
| Docs: https://tensorflownet.readthedocs.io</Description> | |||
| <AssemblyVersion>0.11.4.0</AssemblyVersion> | |||
| <PackageReleaseNotes>Changes since v0.10.0: | |||
| 1. Upgrade NumSharp to v0.20. | |||
| 1. Upgrade NumSharp to v0.20.3. | |||
| 2. Add DisposableObject class to manage object lifetime. | |||
| 3. Add tf.no_op, tf.nn.in_top_k, tf.GraphKeys and tf.trainable_variables. | |||
| 4. Change tensorflow to non-static class in order to execute some initialization process. | |||
| @@ -28,7 +28,8 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||
| 7. Add tf.image related APIs. | |||
| 8. Add tf.random_normal, tf.constant, tf.pad, tf.shape, tf.image.resize_nearest_neighbor. | |||
| 9. MultiThread is safe. | |||
| 10. Support n-dim indexing for tensor.</PackageReleaseNotes> | |||
| 10. Support n-dim indexing for tensor. | |||
| 11. Add RegisterNoGradient</PackageReleaseNotes> | |||
| <LangVersion>7.3</LangVersion> | |||
| <FileVersion>0.11.4.0</FileVersion> | |||
| <PackageLicenseFile>LICENSE</PackageLicenseFile> | |||
| @@ -62,7 +63,7 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||
| <ItemGroup> | |||
| <PackageReference Include="Google.Protobuf" Version="3.5.1" /> | |||
| <PackageReference Include="NumSharp" Version="0.20.2" /> | |||
| <PackageReference Include="NumSharp" Version="0.20.3" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| @@ -2,7 +2,7 @@ | |||
| <PropertyGroup> | |||
| <OutputType>Exe</OutputType> | |||
| <TargetFramework>netcoreapp2.2</TargetFramework> | |||
| <TargetFramework>netcoreapp3.0</TargetFramework> | |||
| <NoWin32Manifest>true</NoWin32Manifest> | |||
| <AssemblyName>TensorFlowBenchmark</AssemblyName> | |||
| <RootNamespace>TensorFlowBenchmark</RootNamespace> | |||
| @@ -0,0 +1,154 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Text; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.Examples | |||
| { | |||
| /// <summary> | |||
| /// How to optimise your input pipeline with queues and multi-threading | |||
| /// https://blog.metaflow.fr/tensorflow-how-to-optimise-your-input-pipeline-with-queues-and-multi-threading-e7c3874157e0 | |||
| /// </summary> | |||
| public class FullyConnected : IExample | |||
| { | |||
| public bool Enabled { get; set; } = true; | |||
| public bool IsImportingGraph { get; set; } | |||
| public string Name => "Fully Connected Neural Network"; | |||
| Tensor input = null; | |||
| Tensor x_inputs_data = null; | |||
| Tensor y_inputs_data = null; | |||
| Tensor accuracy = null; | |||
| Tensor y_true = null; | |||
| Tensor loss_op = null; | |||
| Operation train_op = null; | |||
| public Graph BuildGraph() | |||
| { | |||
| var g = tf.get_default_graph(); | |||
| // batches of 128 samples, each containing 1024 data points | |||
| x_inputs_data = tf.random_normal(new[] { 128, 1024 }, mean: 0, stddev: 1); | |||
| // We will try to predict this law: | |||
| // predict 1 if the sum of the elements is positive and 0 otherwise | |||
| y_inputs_data = tf.cast(tf.reduce_sum(x_inputs_data, axis: 1, keepdims: true) > 0, tf.int32); | |||
| Tensor z = null; | |||
| tf_with(tf.variable_scope("placeholder"), delegate | |||
| { | |||
| input = tf.placeholder(tf.float32, shape: (-1, 1024)); | |||
| y_true = tf.placeholder(tf.int32, shape: (-1, 1)); | |||
| }); | |||
| tf_with(tf.variable_scope("FullyConnected"), delegate | |||
| { | |||
| var w = tf.get_variable("w", shape: (1024, 1024), initializer: tf.random_normal_initializer(stddev: 0.1f)); | |||
| var b = tf.get_variable("b", shape: 1024, initializer: tf.constant_initializer(0.1)); | |||
| z = tf.matmul(input, w) + b; | |||
| var y = tf.nn.relu(z); | |||
| var w2 = tf.get_variable("w2", shape: (1024, 1), initializer: tf.random_normal_initializer(stddev: 0.1f)); | |||
| var b2 = tf.get_variable("b2", shape: 1, initializer: tf.constant_initializer(0.1)); | |||
| z = tf.matmul(y, w2) + b2; | |||
| }); | |||
| tf_with(tf.variable_scope("Loss"), delegate | |||
| { | |||
| var losses = tf.nn.sigmoid_cross_entropy_with_logits(tf.cast(y_true, tf.float32), z); | |||
| loss_op = tf.reduce_mean(losses); | |||
| }); | |||
| tf_with(tf.variable_scope("Accuracy"), delegate | |||
| { | |||
| var y_pred = tf.cast(z > 0, tf.int32); | |||
| accuracy = tf.reduce_mean(tf.cast(tf.equal(y_pred, y_true), tf.float32)); | |||
| // accuracy = tf.Print(accuracy, data =[accuracy], message = "accuracy:") | |||
| }); | |||
| // We add the training operation, ... | |||
| var adam = tf.train.AdamOptimizer(0.01f); | |||
| train_op = adam.minimize(loss_op, name: "train_op"); | |||
| return g; | |||
| } | |||
| public Graph ImportGraph() | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public void Predict(Session sess) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public void PrepareData() | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public bool Run() | |||
| { | |||
| var g = BuildGraph(); | |||
| using (var sess = tf.Session()) | |||
| Train(sess); | |||
| return true; | |||
| } | |||
| public void Test(Session sess) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public void Train(Session sess) | |||
| { | |||
| var sw = new Stopwatch(); | |||
| sw.Start(); | |||
| // init variables | |||
| sess.run(tf.global_variables_initializer()); | |||
| // check the accuracy before training | |||
| var (x_input, y_input) = sess.run((x_inputs_data, y_inputs_data)); | |||
| sess.run(accuracy, (input, x_input), (y_true, y_input)); | |||
| // training | |||
| foreach (var i in range(5000)) | |||
| { | |||
| // by sampling some input data (fetching) | |||
| (x_input, y_input) = sess.run((x_inputs_data, y_inputs_data)); | |||
| var (_, loss) = sess.run((train_op, loss_op), (input, x_input), (y_true, y_input)); | |||
| // We regularly check the loss | |||
| if (i % 500 == 0) | |||
| print($"iter:{i} - loss:{loss}"); | |||
| } | |||
| // Finally, we check our final accuracy | |||
| (x_input, y_input) = sess.run((x_inputs_data, y_inputs_data)); | |||
| sess.run(accuracy, (input, x_input), (y_true, y_input)); | |||
| print($"Time taken: {sw.Elapsed.TotalSeconds}s"); | |||
| } | |||
| } | |||
| } | |||
| @@ -2,7 +2,7 @@ | |||
| <PropertyGroup> | |||
| <OutputType>Exe</OutputType> | |||
| <TargetFramework>netcoreapp2.2</TargetFramework> | |||
| <TargetFramework>netcoreapp3.0</TargetFramework> | |||
| <GeneratePackageOnBuild>false</GeneratePackageOnBuild> | |||
| </PropertyGroup> | |||
| @@ -2,7 +2,7 @@ | |||
| <PropertyGroup> | |||
| <OutputType>Exe</OutputType> | |||
| <TargetFramework>netcoreapp2.2</TargetFramework> | |||
| <TargetFramework>netcoreapp3.0</TargetFramework> | |||
| <GeneratePackageOnBuild>false</GeneratePackageOnBuild> | |||
| </PropertyGroup> | |||
| @@ -1,7 +1,7 @@ | |||
| <Project Sdk="Microsoft.NET.Sdk"> | |||
| <PropertyGroup> | |||
| <TargetFramework>netcoreapp2.2</TargetFramework> | |||
| <TargetFramework>netcoreapp3.0</TargetFramework> | |||
| <IsPackable>false</IsPackable> | |||