diff --git a/README.md b/README.md index 545cea13..8a9df45c 100644 --- a/README.md +++ b/README.md @@ -150,6 +150,8 @@ Example runner will download all the required files like training data and model * [Named Entity Recognition](test/TensorFlowNET.Examples/TextProcess/NER) * [Transfer Learning for Image Classification in InceptionV3](test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs) +More troubleshooting of running example refer [here](tensorflowlib/README.md). + ### Contribute: Feel like contributing to one of the hottest projects in the Machine Learning field? Want to know how Tensorflow magically creates the computational graph? We appreciate every contribution however small. There are tasks for novices to experts alike, if everyone tackles only a small task the sum of contributions will be huge. diff --git a/docs/_config.yml b/docs/_config.yml new file mode 100644 index 00000000..c4192631 --- /dev/null +++ b/docs/_config.yml @@ -0,0 +1 @@ +theme: jekyll-theme-cayman \ No newline at end of file diff --git a/graph/InceptionV3.meta b/graph/InceptionV3.meta index 0ded6221..2a11b082 100644 Binary files a/graph/InceptionV3.meta and b/graph/InceptionV3.meta differ diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index cbcfed28..4e733f18 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -36,6 +36,16 @@ namespace Tensorflow public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1) => array_ops.expand_dims(input, axis, name, dim); + /// + /// Creates a tensor filled with a scalar value. + /// + /// + /// + /// + /// + public static Tensor fill(Tensor dims, T value, string name = null) + => gen_array_ops.fill(dims, value, name: name); + /// /// Return the elements, either from `x` or `y`, depending on the `condition`. /// diff --git a/src/TensorFlowNET.Core/APIs/tf.gradients.cs b/src/TensorFlowNET.Core/APIs/tf.gradients.cs index 77491d55..0961288c 100644 --- a/src/TensorFlowNET.Core/APIs/tf.gradients.cs +++ b/src/TensorFlowNET.Core/APIs/tf.gradients.cs @@ -6,7 +6,7 @@ namespace Tensorflow { public static partial class tf { - public static object gradients(Tensor[] ys, + public static Tensor[] gradients(Tensor[] ys, Tensor[] xs, Tensor[] grad_ys = null, string name = "gradients", @@ -15,7 +15,7 @@ namespace Tensorflow int? aggregation_method = null, Tensor[] stop_gradients = null) { - return gradients_impl._GradientsHelper(ys, + return gradients_util._GradientsHelper(ys, xs, grad_ys, name, @@ -33,7 +33,7 @@ namespace Tensorflow int? aggregation_method = null, Tensor[] stop_gradients = null) { - return gradients_impl._GradientsHelper(new Tensor[] { ys }, + return gradients_util._GradientsHelper(new Tensor[] { ys }, xs, grad_ys, name, @@ -41,5 +41,23 @@ namespace Tensorflow gate_gradients, stop_gradients: stop_gradients); } + + public static Tensor[] gradients(Tensor ys, + Tensor xs, + Tensor[] grad_ys = null, + string name = "gradients", + bool colocate_gradients_with_ops = false, + bool gate_gradients = false, + int? 
aggregation_method = null, + Tensor[] stop_gradients = null) + { + return gradients_util._GradientsHelper(new Tensor[] { ys }, + new Tensor[] { xs }, + grad_ys, + name, + colocate_gradients_with_ops, + gate_gradients, + stop_gradients: stop_gradients); + } } } diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs index faf0d089..ea35e869 100644 --- a/src/TensorFlowNET.Core/APIs/tf.layers.cs +++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs @@ -142,6 +142,7 @@ namespace Tensorflow var layer = new Dense(units, activation, use_bias: use_bias, + bias_initializer: bias_initializer, kernel_initializer: kernel_initializer); return layer.apply(inputs); diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index bad41103..a8ec223a 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -257,6 +257,16 @@ namespace Tensorflow public static Tensor negative(Tensor x, string name = null) => gen_math_ops.neg(x, name); + /// + /// Divides x / y elementwise (using Python 2 division operator semantics). + /// + /// + /// + /// + /// + public static Tensor div(Tensor x, Tensor y, string name = null) + => math_ops.div(x, y, name: name); + public static Tensor divide(Tensor x, T[] y, string name = null) where T : struct => x / ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"); diff --git a/src/TensorFlowNET.Core/APIs/tf.variable.cs b/src/TensorFlowNET.Core/APIs/tf.variable.cs index 266d5799..a1b3e1d8 100644 --- a/src/TensorFlowNET.Core/APIs/tf.variable.cs +++ b/src/TensorFlowNET.Core/APIs/tf.variable.cs @@ -23,6 +23,8 @@ namespace Tensorflow TF_DataType dtype = TF_DataType.DtInvalid, object initializer = null, // IInitializer or Tensor bool? trainable = null, + bool? use_resource = null, + bool validate_shape = true, VariableSynchronization synchronization = VariableSynchronization.Auto, VariableAggregation aggregation = VariableAggregation.None) { @@ -32,6 +34,8 @@ namespace Tensorflow name, shape: shape, dtype: dtype, + use_resource: use_resource, + validate_shape: validate_shape, initializer: initializer, trainable: trainable); } diff --git a/src/TensorFlowNET.Core/Framework/CompositeTensor.cs b/src/TensorFlowNET.Core/Framework/CompositeTensor.cs new file mode 100644 index 00000000..eac74580 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/CompositeTensor.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Framework +{ + /// + /// Abstract base class for Tensor-like objects that are composed from Tensors. + /// + public abstract class CompositeTensor + { + } +} diff --git a/src/TensorFlowNET.Core/Framework/IndexedSlices.cs b/src/TensorFlowNET.Core/Framework/IndexedSlices.cs new file mode 100644 index 00000000..0c4f0c8b --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/IndexedSlices.cs @@ -0,0 +1,48 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Framework +{ + /// + /// A sparse representation of a set of tensor slices at given indices. 
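// Illustrative note (assumed example values, following TensorFlow's
// IndexedSlices semantics): values of shape [2, 4] with indices [0, 3] and
// dense_shape [5, 4] stand for a dense [5, 4] tensor that is zero everywhere
// except rows 0 and 3, which hold the two value rows.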
+ /// + public class IndexedSlices : CompositeTensor + { + Tensor _values; + public Tensor values => _values; + Tensor _indices; + public Tensor indices => _indices; + Tensor _dense_shape; + public Tensor dense_shape => _dense_shape; + + public string name => _values.name; + + public string device => _values.Device; + + public Operation op => _values.op; + + public TF_DataType dtype => _values.dtype; + + public Graph graph => _values.graph; + + public IndexedSlices(Tensor values, Tensor indices, Tensor dense_shape = null) + { + _values = values; + _indices = indices; + _dense_shape = dense_shape; + + _values.Tag = this; + } + + public static implicit operator Tensor(IndexedSlices indexedSlices) + { + return indexedSlices.values; + } + + public static implicit operator IndexedSlices(Tensor tensor) + { + return tensor.Tag as IndexedSlices; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs index b7c5494a..4896d4dd 100644 --- a/src/TensorFlowNET.Core/Gradients/array_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow.Framework; using Tensorflow.Operations; using static Tensorflow.Python; @@ -42,9 +43,9 @@ namespace Tensorflow.Gradients return end_value_index <= dim_index ? new Tensor[] { grad, null } : new Tensor[] { null, grad }; var concat_dim = op.inputs[dim_index]; - if (end_value_index == -1) - end_value_index = op.inputs.Length - 1; - var input_values = op.inputs._inputs.Skip(start_value_index).Take(end_value_index - start_value_index).ToArray(); + var input_values = op.inputs._inputs.Skip(start_value_index) + .Take(end_value_index == -1 ? op.inputs.Length - 1 : end_value_index - start_value_index) + .ToArray(); var out_grads = new List(); if (constant_op.is_constant(concat_dim)) @@ -82,20 +83,26 @@ namespace Tensorflow.Gradients new Tensor[] { non_neg_concat_dim, tf.constant(0) }, new Tensor[] { tf.constant(1), tf.constant(-1) }); var squeeze_sizes = array_ops.squeeze(slice); - out_grads = gen_ops.split(grad, squeeze_sizes, non_neg_concat_dim).ToList(); + out_grads = gen_array_ops.split(grad, squeeze_sizes, non_neg_concat_dim).ToList(); } else { - var offset = gen_ops.concat_offset(non_neg_concat_dim, sizes); + var offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes); foreach (var (begin, size) in zip(offset, sizes)) - out_grads.Add(gen_ops.slice(grad, begin, size)); + out_grads.Add(gen_array_ops.slice(grad, begin, size)); } return (end_value_index <= dim_index ? - out_grads.ToArray().Concat(null) : + out_grads.ToArray().Concat(new Tensor[] { null }) : new Tensor[] { null }.Concat(out_grads)).ToArray(); } + [RegisterGradient("ExpandDims")] + public static Tensor[] _ExpandDimsGrad(Operation op, Tensor[] grads) + { + return new Tensor[] { _ReshapeToInput(op, grads[0]), null }; + } + /// /// Extract the shapes of a set of input tensors. /// @@ -122,7 +129,46 @@ namespace Tensorflow.Gradients if (fully_known) return sizes; else - return gen_ops.shape_n(inputs); + return gen_array_ops.shape_n(inputs); + } + + /// + /// Gradient for GatherV2 op. 
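// Sketch of the axis-0 case handled below (shapes assumed for illustration):
// for params of shape [5, 4] and indices = [0, 3], the incoming grad has
// shape [2, 4]; it is reshaped into values and wrapped as
//   new IndexedSlices(values, indices, params_shape)
// so the gradient rows are scattered back to rows 0 and 3 of params instead
// of materializing a dense zero-filled gradient.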
+ /// + /// + /// + /// + [RegisterGradient("GatherV2")] + public static Tensor[] _GatherV2Grad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var @params = op.inputs[0]; + ops.colocate_with(@params); + + var params_shape = array_ops.shape(@params, out_type: tf.int64); + params_shape = math_ops.cast(params_shape, tf.int32); + + var indices = op.inputs[1]; + var indices_size = array_ops.expand_dims(array_ops.size(indices), 0); + var axis = op.inputs[2]; + var axis_static = tensor_util.constant_value(axis); + + // For axis 0 gathers, build an appropriately shaped IndexedSlices. + if((int)axis_static == 0) + { + var params_tail_shape = params_shape[new NumSharp.Slice(start:1)]; + var values_shape = array_ops.concat(new[] { indices_size, params_tail_shape }, 0); + var values = array_ops.reshape(grad, values_shape); + indices = array_ops.reshape(indices, indices_size); + return new Tensor[] + { + new IndexedSlices(values, indices, params_shape), + null, + null + }; + } + + return new Tensor[] { null, null }; } [RegisterGradient("Reshape")] diff --git a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs index 18151ac5..8ad4b44e 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs @@ -1,5 +1,4 @@ -using NumSharp; -using System; +using System; using System.Collections.Generic; using System.Linq; using System.Text; @@ -18,487 +17,7 @@ namespace Tensorflow bool gate_gradients = false, int? aggregation_method = null) { - return _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients); - } - - public static Tensor[] _GradientsHelper(Tensor[] ys, - Tensor[] xs, - Tensor[] grad_ys = null, - string name = "gradients", - bool colocate_gradients_with_ops = false, - bool gate_gradients = false, - int aggregation_method = 0, - Tensor[] stop_gradients = null, - Graph src_graph = null) - { - if (src_graph == null) - src_graph = ops.get_default_graph(); - - // If src_graph is a _FuncGraph (i.e. a function body), gather it and all - // ancestor graphs. This is necessary for correctly handling captured values. - var curr_graph = src_graph; - - if (stop_gradients == null) - stop_gradients = new Tensor[0]; - if (grad_ys == null) - grad_ys = new Tensor[ys.Length]; - - // Iterate over the collected ops. - /** - * grads: op => list of gradients received on each output endpoint of the - * op. The gradients for each endpoint are initially collected as a list. - * When it is time to call the op's gradient function, for each endpoint we - * aggregate the list of received gradients into a Add() Operation if there - * is more than one. - **/ - var grads = new Dictionary(); - - with(ops.name_scope(name, "gradients", - values: ys.Concat(xs).Concat(stop_gradients).Concat(grad_ys)), scope => - { - string grad_scope = scope; - // Get a uid for this call to gradients that can be used to help - // cluster ops for compilation. - var gradient_uid = ops.get_default_graph().unique_name("uid"); - ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name: "y"); - xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name: "x", as_ref: true); - grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid); - - /** - * The approach we take here is as follows: Create a list of all ops in the - * subgraph between the ys and xs. 
Visit these ops in reverse order of ids - * to ensure that when we visit an op the gradients w.r.t its outputs have - * been collected. Then aggregate these gradients if needed, call the op's - * gradient function, and add the generated gradients to the gradients for - * its input. - **/ - - // Initialize the pending count for ops in the connected subgraph from ys - // to the xs. - var to_ops = ys.Select(x => x.op).ToList(); - var from_ops = xs.Select(x => x.op).ToList(); - var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); - (var reachable_to_ops, var pending_count, var loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List(), xs); - - foreach(var (y, grad_y) in Python.zip(ys, grad_ys)) - _SetGrad(grads, y, grad_y); - - // Initialize queue with to_ops. - var queue = new Queue(); - // Add the ops in 'to_ops' into the queue. - var to_ops_set = new List(); - foreach (var op in to_ops) - { - // 'ready' handles the case where one output gradient relies on - // another output's gradient. - if (!pending_count.ContainsKey(op.name)) - pending_count[op.name] = 0; - bool ready = pending_count[op.name] == 0; - if(ready && !to_ops_set.Contains(op) && reachable_to_ops.Contains(op)) - { - to_ops_set.Add(op); - queue.Enqueue(op); - } - } - - var stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs); - while(queue.Count > 0) - { - // generate gradient subgraph for op. - var op = queue.Dequeue(); - _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops); - //if (loop_state != null) - //loop_state.EnterGradWhileContext(op, before: true); - var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method); - - Tensor[] in_grads = null; - var is_partitioned_call = _IsPartitionedCall(op); - var is_func_call = false; - var has_out_grads = true; - if (has_out_grads && !stop_ops.Contains(op)) - { - if (is_func_call) - { - - } - else - { - // A grad_fn must be defined, either as a function or as None - // for ops that do not have gradients. - var grad_fn = ops.get_gradient_function(op); - - foreach(var (i, out_grad) in enumerate(out_grads)) - { - if(out_grad == null) - { - if (loop_state != null) - ; - else - out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i); - } - } - - with(ops.name_scope(op.name + "_grad"), scope1 => - { - string name1 = scope1; - if (grad_fn != null) - { - in_grads = _MaybeCompile(grad_scope, op, out_grads, null, grad_fn); - _VerifyGeneratedGradients(in_grads, op); - } - - if (gate_gradients && in_grads.Count(x => x != null) > 1) - { - ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true); - in_grads = control_flow_ops.tuple(in_grads); - } - }); - } - } - else - { - in_grads = new Tensor[_NonEagerInputs(op, xs).Count()]; - } - - var inputs = _NonEagerInputs(op, xs).ToList(); - foreach (var (t_in, in_grad) in zip(inputs, in_grads)) - { - if(in_grad != null) - { - if(in_grad is Tensor && t_in.dtype != TF_DataType.TF_RESOURCE) - { - in_grad.shape = t_in.shape; - } - - _SetGrad(grads, t_in, in_grad); - } - } - - // Update pending count for the inputs of op and enqueue ready ops. - _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, xs); - } - }); - - return xs.Select(x => _GetGrad(grads, x)).ToArray(); - } - - /// - /// Update pending count for the inputs of op and enqueue ready ops. 
- /// - /// - /// - /// - /// - /// - /// - private static void _UpdatePendingAndEnqueueReady(Dictionary grads, - Operation op, - Queue queue, - Dictionary pending_count, - object loop_state, - Tensor[] xs) - { - foreach(var x in _NonEagerInputs(op, xs)) - { - if (!pending_count.ContainsKey(x.op.name)) - pending_count[x.op.name] = 0; - - pending_count[x.op.name] -= 1; - - var ready = pending_count[x.op.name] == 0; - - if(loop_state != null && !ready) - { - - } - - if (ready) - { - if (control_flow_util.IsLoopExit(x.op)) - { - - } - else - { - queue.Enqueue(x.op); - } - } - } - } - - private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op) - { - if (grads.Count() != op.inputs._inputs.Count()) - throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " + - $"inputs {op.inputs._inputs.Count()}"); - } - - private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func grad_fn) - { - scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope; - return grad_fn(op, out_grads); - } - - private static bool _IsPartitionedCall(Operation op) - { - return op.OpType == "PartitionedCall" || op.OpType == "StatefulPartitionedCall"; - } - - private static Tensor[] _AggregatedGrads(Dictionary grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0) - { - var out_grads = _GetGrads(grads, op); - var return_grads = new Tensor[out_grads.Length]; - - foreach(var (i, out_grad) in enumerate(out_grads)) - { - if (loop_state != null) - { - - } - - // Aggregate multiple gradients, and convert [] to None. - if (out_grad != null) - { - if (out_grad.Length < 2) - { - string used = "nop"; - return_grads[i] = out_grad[0]; - } - } - } - - return return_grads; - } - - /// - /// The set of ops that terminate the gradient computation. - /// - /// list of Operations. - /// list of Operations never to backprop through. - /// mapping from operation to number of backprop inputs. - /// list of Tensors. - /// The set of operations. - private static Operation[] _StopOps(List from_ops, List stop_gradient_ops, Dictionary pending_count, Tensor[] xs) - { - var stop_ops = new List(); - - foreach(var op in from_ops) - { - bool is_stop_op = true; - foreach(var inp in _NonEagerInputs(op, xs)) - { - if (!pending_count.ContainsKey(inp.op.name)) - pending_count[inp.op.name] = 0; - - if (pending_count[inp.op.name] > 0) - { - is_stop_op = false; - break; - } - } - if (is_stop_op) - stop_ops.Insert(0, op); - } - stop_ops.AddRange(stop_gradient_ops.Where(x => !stop_ops.Contains(x))); - return stop_ops.ToArray(); - } - - private static Tensor _GetGrad(Dictionary grads, Tensor t) - { - var op = t.op; - if (!grads.ContainsKey(op.name)) - return null; - Tensor[][] op_grads = grads[op.name]; - var t_grad = op_grads[t.value_index]; - return t_grad[0]; - } - - private static Tensor[][] _GetGrads(Dictionary grads, Operation op) - { - if (grads.ContainsKey(op.name)) - return grads[op.name]; - else - return op.outputs.Select(x => new Tensor[0]).ToArray(); - } - - /// - /// Sets gradient "grad" in "grads" for tensor "t". - /// - /// - /// - /// - private static void _SetGrad(Dictionary grads, Tensor t, Tensor grad) - { - var op = t.op; - Tensor[][] op_grads = grads.ContainsKey(op.name) ? 
grads[op.name] : null; - if (op_grads == null) - { - op_grads = op.outputs.Select(x => new Tensor[1]).ToArray(); - grads[op.name] = op_grads; - } - var t_grads = op_grads[t.value_index]; - t_grads[0] = grad; - } - - /// - /// Fill in default values for grad_ys. - /// - /// List of gradients, can contain None. - /// List of tensors. - /// - /// - private static Tensor[] _DefaultGradYs(Tensor[] grad_ys, Tensor[] ys, bool colocate_gradients_with_ops, string gradient_uid = "__unsupported__") - { - var new_grad_ys = new List(); - - for(int i = 0; i < grad_ys.Length; i++) - { - var grad_y = grad_ys[i]; - var y = ys[i]; - - _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops); - - if(grad_y == null) - { - if (y.dtype.is_complex()) - throw new TypeAccessException($"Gradients of complex tensors must set grad_ys (y.dtype = {y.dtype})"); - var shape = array_ops.shape(y); - var constant = constant_op.constant(y.dtype == TF_DataType.TF_DOUBLE ? (object)1.0 : (object)1.0f, name: $"grad_ys_{i}"); - var fill = gen_array_ops.fill(shape, constant); - new_grad_ys.Add(fill); - } - } - - return new_grad_ys.ToArray(); - } - - private static void _maybe_colocate_with(Operation op, string gradient_uid, bool colocate_gradients_with_ops) - { - - } - - /// - /// Initialize the pending count for ops between two lists of Operations. - /// 'pending_count[op]' indicates the number of backprop inputs - /// to this operation. - /// - /// - /// - /// - /// - /// - private static (Operation[], Dictionary, object) _PendingCount(List to_ops, List from_ops, bool colocate_gradients_with_ops, List func_graphs, Tensor[] xs) - { - // Mark reachable ops from from_ops. - var reached_ops = new List(); - _MarkReachedOps(from_ops, reached_ops, func_graphs); - // X in reached_ops iff X is reachable from from_ops by a path of zero or more - // backpropagatable tensors. - - var reachable_to_ops = to_ops.Where(x => reached_ops.Contains(x)).Select(x => x).ToArray(); - - var between_ops = new List(); - var between_op_list = new List(); - - Queue queue = new Queue(to_ops); - while(queue.Count > 0) - { - var op = queue.Dequeue(); - if (reached_ops.Contains(op)) - { - between_ops.Add(op); - between_op_list.Insert(between_op_list.Count, op); - // Clear the boolean so we won't add the inputs again. - reached_ops.Remove(op); - foreach (var inp in _NonEagerInputs(op, xs)) - queue.Enqueue(inp.op); - } - } - // X in between_ops iff X is on a path of zero or more backpropagatable tensors - // between from_ops and to_ops - - // 'loop_state' is None if there are no while loops. 
- var loop_state = control_flow_ops.MaybeCreateControlFlowState(between_op_list, between_ops, colocate_gradients_with_ops); - - var pending_count = new Dictionary(); - foreach (var op in between_op_list) - { - foreach(Tensor x in _NonEagerInputs(op, xs)) - { - if (between_ops.Contains(x.op)) - { - if (!pending_count.ContainsKey(x.op.name)) - pending_count[x.op.name] = 0; - - pending_count[x.op.name] += 1; - } - } - } - - return (reachable_to_ops.ToArray(), pending_count, loop_state); - } - - private static IEnumerable _NonEagerInputs(Operation op, Tensor[] xs) - { - for (int i = 0; i < op.inputs.Length; i++) - yield return op.inputs[i]; - } - - /// - /// Mark all ops reached from "from_ops" - /// - /// - /// - /// - private static void _MarkReachedOps(List from_ops, List reached_ops, List func_graphs) - { - Queue queue = new Queue(from_ops); - while (queue.Count > 0) - { - var op = queue.Dequeue(); - - if (!reached_ops.Contains(op)) - { - reached_ops.Add(op); - foreach (var output in op.outputs) - { - if (_IsBackpropagatable(output)) - { - var c = _Consumers(output, func_graphs).ToList(); - c.ForEach(x => queue.Enqueue(x)); - } - } - } - } - } - - private static bool _IsTrainable(Tensor tensor) - { - var dtype = tensor.dtype.as_base_dtype(); - return new TF_DataType[] {TF_DataType.TF_HALF, TF_DataType.TF_FLOAT, TF_DataType.TF_DOUBLE, - TF_DataType.TF_COMPLEX64, TF_DataType.TF_COMPLEX128, TF_DataType.TF_RESOURCE}.Contains(dtype); - } - private static bool _IsBackpropagatable(Tensor tensor) - { - if(_IsTrainable(tensor)) - { - return true; - } - else - { - var dtype = tensor.dtype.as_base_dtype(); - return new TF_DataType[] { TF_DataType.TF_BFLOAT16, TF_DataType.TF_VARIANT }.Contains(dtype); - } - } - - /// - /// Returns the consumers of t, crossing closure boundaries where necessary. - /// - /// - /// - private static Operation[] _Consumers(Tensor t, List func_graphs) - { - return t.consumers(); + return gradients_util._GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients); } private static List _AsList(object ys) diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs new file mode 100644 index 00000000..12a50479 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -0,0 +1,540 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using static Tensorflow.Python; + +namespace Tensorflow +{ + public class gradients_util + { + public static Tensor[] _GradientsHelper(Tensor[] ys, + Tensor[] xs, + Tensor[] grad_ys = null, + string name = "gradients", + bool colocate_gradients_with_ops = false, + bool gate_gradients = false, + int aggregation_method = 0, + Tensor[] stop_gradients = null, + Graph src_graph = null) + { + if (src_graph == null) + src_graph = ops.get_default_graph(); + + // If src_graph is a _FuncGraph (i.e. a function body), gather it and all + // ancestor graphs. This is necessary for correctly handling captured values. + var curr_graph = src_graph; + + if (stop_gradients == null) + stop_gradients = new Tensor[0]; + if (grad_ys == null) + grad_ys = new Tensor[ys.Length]; + + // Iterate over the collected ops. + /** + * grads: op => list of gradients received on each output endpoint of the + * op. The gradients for each endpoint are initially collected as a list. 
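 * (Illustrative shape of the map, with an assumed op name: an op "MatMul_1"
 * with two outputs is stored as
 *   grads["MatMul_1"] = new List<List<Tensor>>
 *   {
 *       new List<Tensor>(),  // gradients received on output 0
 *       new List<Tensor>()   // gradients received on output 1
 *   };
 * and _AggregatedGrads later folds any multi-entry list into one AddN.)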
+ * When it is time to call the op's gradient function, for each endpoint we + * aggregate the list of received gradients into a Add() Operation if there + * is more than one. + **/ + var grads = new Dictionary>>(); + + with(ops.name_scope(name, "gradients", + values: ys.Concat(xs).Concat(stop_gradients).Concat(grad_ys)), scope => + { + string grad_scope = scope; + // Get a uid for this call to gradients that can be used to help + // cluster ops for compilation. + var gradient_uid = ops.get_default_graph().unique_name("uid"); + ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name: "y"); + xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name: "x", as_ref: true); + grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid); + + /** + * The approach we take here is as follows: Create a list of all ops in the + * subgraph between the ys and xs. Visit these ops in reverse order of ids + * to ensure that when we visit an op the gradients w.r.t its outputs have + * been collected. Then aggregate these gradients if needed, call the op's + * gradient function, and add the generated gradients to the gradients for + * its input. + **/ + + // Initialize the pending count for ops in the connected subgraph from ys + // to the xs. + var to_ops = ys.Select(x => x.op).ToList(); + var from_ops = xs.Select(x => x.op).ToList(); + var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); + (var reachable_to_ops, var pending_count, var loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List(), xs); + + foreach (var (y, grad_y) in zip(ys, grad_ys)) + _SetGrad(grads, y, grad_y); + + // Initialize queue with to_ops. + var queue = new Queue(); + // Add the ops in 'to_ops' into the queue. + var to_ops_set = new List(); + foreach (var op in to_ops) + { + // 'ready' handles the case where one output gradient relies on + // another output's gradient. + if (!pending_count.ContainsKey(op.name)) + pending_count[op.name] = 0; + bool ready = pending_count[op.name] == 0; + if (ready && !to_ops_set.Contains(op) && reachable_to_ops.Contains(op)) + { + to_ops_set.Add(op); + queue.Enqueue(op); + } + } + + var stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs); + while (queue.Count > 0) + { + // generate gradient subgraph for op. + var op = queue.Dequeue(); + + _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops); + //if (loop_state != null) + //loop_state.EnterGradWhileContext(op, before: true); + var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method); + + Tensor[] in_grads = null; + var is_partitioned_call = _IsPartitionedCall(op); + var is_func_call = false; + var has_out_grads = true; + if (has_out_grads && !stop_ops.Contains(op)) + { + if (is_func_call) + { + + } + else + { + // A grad_fn must be defined, either as a function or as None + // for ops that do not have gradients. 
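// Minimal sketch of a registered gradient function that this lookup would
// resolve (the attribute matches RegisterGradient as used elsewhere in this
// patch; the body is schematic, not taken from the sources):
//
//   [RegisterGradient("Square")]
//   public static Tensor[] _SquareGrad(Operation op, Tensor[] grads)
//   {
//       // d(x^2)/dx = 2x, chained with the upstream gradient grads[0].
//       var x = op.inputs[0];
//       var two = constant_op.constant(2.0f);
//       return new Tensor[] { math_ops.multiply(grads[0], math_ops.multiply(two, x)) };
//   }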
+ var grad_fn = ops.get_gradient_function(op); + + foreach (var (i, out_grad) in enumerate(out_grads)) + { + if (out_grad == null) + { + if (loop_state != null) + ; + else + out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i); + } + } + + with(ops.name_scope(op.name + "_grad"), scope1 => + { + string name1 = scope1; + if (grad_fn != null) + { + in_grads = _MaybeCompile(grad_scope, op, out_grads, null, grad_fn); + _VerifyGeneratedGradients(in_grads, op); + } + + if (gate_gradients && in_grads.Count(x => x != null) > 1) + { + ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true); + in_grads = control_flow_ops.tuple(in_grads); + } + }); + } + } + else + { + in_grads = new Tensor[_NonEagerInputs(op, xs).Count()]; + } + + var inputs = _NonEagerInputs(op, xs).ToList(); + foreach (var (t_in, in_grad) in zip(inputs, in_grads)) + { + if (in_grad != null) + { + if (in_grad is Tensor && t_in.dtype != TF_DataType.TF_RESOURCE) + { + in_grad.shape = t_in.shape; + } + + _SetGrad(grads, t_in, in_grad); + } + } + + // Update pending count for the inputs of op and enqueue ready ops. + _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, xs); + } + }); + + return xs.Select(x => _GetGrad(grads, x)).ToArray(); + } + + /// + /// Fill in default values for grad_ys. + /// + /// List of gradients, can contain None. + /// List of tensors. + /// + /// + private static Tensor[] _DefaultGradYs(Tensor[] grad_ys, Tensor[] ys, bool colocate_gradients_with_ops, string gradient_uid = "__unsupported__") + { + var new_grad_ys = new List(); + + for (int i = 0; i < grad_ys.Length; i++) + { + var grad_y = grad_ys[i]; + var y = ys[i]; + + _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops); + + if (grad_y == null) + { + if (y.dtype.is_complex()) + throw new TypeAccessException($"Gradients of complex tensors must set grad_ys (y.dtype = {y.dtype})"); + var shape = array_ops.shape(y); + var constant = constant_op.constant(y.dtype == TF_DataType.TF_DOUBLE ? (object)1.0 : (object)1.0f, name: $"grad_ys_{i}"); + var fill = gen_array_ops.fill(shape, constant); + new_grad_ys.Add(fill); + } + } + + return new_grad_ys.ToArray(); + } + + private static void _maybe_colocate_with(Operation op, string gradient_uid, bool colocate_gradients_with_ops) + { + + } + + /// + /// Initialize the pending count for ops between two lists of Operations. + /// 'pending_count[op]' indicates the number of backprop inputs + /// to this operation. + /// + /// + /// + /// + /// + /// + private static (Operation[], Dictionary, object) _PendingCount(List to_ops, List from_ops, bool colocate_gradients_with_ops, List func_graphs, Tensor[] xs) + { + // Mark reachable ops from from_ops. + var reached_ops = new List(); + _MarkReachedOps(from_ops, reached_ops, func_graphs); + // X in reached_ops iff X is reachable from from_ops by a path of zero or more + // backpropagatable tensors. + + var reachable_to_ops = to_ops.Where(x => reached_ops.Contains(x)).Select(x => x).ToArray(); + + var between_ops = new List(); + var between_op_list = new List(); + + Queue queue = new Queue(to_ops); + while (queue.Count > 0) + { + var op = queue.Dequeue(); + if (reached_ops.Contains(op)) + { + between_ops.Add(op); + between_op_list.Insert(between_op_list.Count, op); + // Clear the boolean so we won't add the inputs again. 
+ reached_ops.Remove(op); + foreach (var inp in _NonEagerInputs(op, xs)) + queue.Enqueue(inp.op); + } + } + // X in between_ops iff X is on a path of zero or more backpropagatable tensors + // between from_ops and to_ops + + // 'loop_state' is None if there are no while loops. + var loop_state = control_flow_ops.MaybeCreateControlFlowState(between_op_list, between_ops, colocate_gradients_with_ops); + + var pending_count = new Dictionary(); + foreach (var op in between_op_list) + { + foreach (Tensor x in _NonEagerInputs(op, xs)) + { + if (between_ops.Contains(x.op)) + { + if (!pending_count.ContainsKey(x.op.name)) + pending_count[x.op.name] = 0; + + pending_count[x.op.name] += 1; + } + } + } + + return (reachable_to_ops.ToArray(), pending_count, loop_state); + } + + /// + /// Sets gradient "grad" in "grads" for tensor "t". + /// + /// + /// + /// + private static void _SetGrad(Dictionary>> grads, Tensor t, Tensor grad) + { + var op = t.op; + var op_grads = grads.ContainsKey(op.name) ? grads[op.name] : null; + if (op_grads == null) + { + op_grads = op.outputs.Select(x => new List()).ToList(); + grads[op.name] = op_grads; + } + var t_grads = op_grads[t.value_index]; + t_grads.Add(grad); + } + + private static IEnumerable _NonEagerInputs(Operation op, Tensor[] xs) + { + for (int i = 0; i < op.inputs.Length; i++) + yield return op.inputs[i]; + } + + private static Tensor[] _AggregatedGrads(Dictionary>> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0) + { + var out_grads = _GetGrads(grads, op); + var return_grads = new Tensor[out_grads.Count]; + + foreach (var (i, out_grad) in enumerate(out_grads)) + { + if (loop_state != null) + { + + } + + // Aggregate multiple gradients, and convert [] to None. + if (out_grad.Count > 0) + { + string used = ""; + if (out_grad.Count < 2) + { + used = "nop"; + if (out_grad.Count == 0) + { + throw new ValueError("_AggregatedGrads out_grad.Length == 0"); + } + + return_grads[i] = out_grad[0]; + } + else + { + used = "add_n"; + out_grads[i] = new List { _MultiDeviceAddN(out_grad.ToArray(), gradient_uid) }; + } + } + else + { + return_grads[i] = null; + } + } + + return return_grads; + } + + /// + /// Adds tensors from potentially multiple devices. + /// + /// + /// + /// + private static Tensor _MultiDeviceAddN(Tensor[] tensor_list, string gradient_uid) + { + // Basic function structure comes from control_flow_ops.group(). + // Sort tensors according to their devices. + var tensors_on_device = new Dictionary>(); + + foreach (var tensor in tensor_list) + { + if (!tensors_on_device.ContainsKey(tensor.Device)) + tensors_on_device[tensor.Device] = new List(); + + tensors_on_device[tensor.Device].Add(tensor); + } + + // For each device, add the tensors on that device first. + var summands = new List(); + foreach(var dev in tensors_on_device.Keys) + { + var tensors = tensors_on_device[dev]; + ops._colocate_with_for_gradient(tensors[0].op, gradient_uid, ignore_existing: true); + summands.Add(math_ops.add_n(tensors.ToArray())); + } + + return math_ops.add_n(summands.ToArray()); + } + + /// + /// The set of ops that terminate the gradient computation. + /// + /// list of Operations. + /// list of Operations never to backprop through. + /// mapping from operation to number of backprop inputs. + /// list of Tensors. + /// The set of operations. 
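// Reading of the rule implemented below: an op taken from from_ops is a
// "stop op" when none of its non-eager inputs still carries a positive
// pending count, i.e. no gradient can still flow into it from the subgraph
// between ys and xs, so the backward walk terminates there.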
+ private static Operation[] _StopOps(List from_ops, List stop_gradient_ops, Dictionary pending_count, Tensor[] xs) + { + var stop_ops = new List(); + + foreach (var op in from_ops) + { + bool is_stop_op = true; + foreach (var inp in _NonEagerInputs(op, xs)) + { + if (!pending_count.ContainsKey(inp.op.name)) + pending_count[inp.op.name] = 0; + + if (pending_count[inp.op.name] > 0) + { + is_stop_op = false; + break; + } + } + if (is_stop_op) + stop_ops.Insert(0, op); + } + stop_ops.AddRange(stop_gradient_ops.Where(x => !stop_ops.Contains(x))); + return stop_ops.ToArray(); + } + + private static Tensor _GetGrad(Dictionary>> grads, Tensor t) + { + var op = t.op; + if (!grads.ContainsKey(op.name)) + return null; + var op_grads = grads[op.name]; + var t_grad = op_grads[t.value_index]; + return t_grad[0]; + } + + private static List> _GetGrads(Dictionary>> grads, Operation op) + { + if (grads.ContainsKey(op.name)) + return grads[op.name]; + else + return op.outputs.Select(x => new List()).ToList(); + } + + /// + /// Mark all ops reached from "from_ops" + /// + /// + /// + /// + private static void _MarkReachedOps(List from_ops, List reached_ops, List func_graphs) + { + Queue queue = new Queue(from_ops); + while (queue.Count > 0) + { + var op = queue.Dequeue(); + + if (!reached_ops.Contains(op)) + { + reached_ops.Add(op); + foreach (var output in op.outputs) + { + if (_IsBackpropagatable(output)) + { + var c = _Consumers(output, func_graphs).ToList(); + c.ForEach(x => queue.Enqueue(x)); + } + } + } + } + } + + /// + /// Returns the consumers of t, crossing closure boundaries where necessary. + /// + /// + /// + private static Operation[] _Consumers(Tensor t, List func_graphs) + { + return t.consumers(); + } + + private static bool _IsBackpropagatable(Tensor tensor) + { + if (_IsTrainable(tensor)) + { + return true; + } + else + { + var dtype = tensor.dtype.as_base_dtype(); + return new TF_DataType[] { TF_DataType.TF_BFLOAT16, TF_DataType.TF_VARIANT }.Contains(dtype); + } + } + + private static bool _IsTrainable(Tensor tensor) + { + var dtype = tensor.dtype.as_base_dtype(); + return new TF_DataType[] {TF_DataType.TF_HALF, TF_DataType.TF_FLOAT, TF_DataType.TF_DOUBLE, + TF_DataType.TF_COMPLEX64, TF_DataType.TF_COMPLEX128, TF_DataType.TF_RESOURCE}.Contains(dtype); + } + + private static bool _IsPartitionedCall(Operation op) + { + return op.OpType == "PartitionedCall" || op.OpType == "StatefulPartitionedCall"; + } + + /// + /// Update pending count for the inputs of op and enqueue ready ops. + /// + /// + /// + /// + /// + /// + /// + private static void _UpdatePendingAndEnqueueReady(Dictionary>> grads, + Operation op, + Queue queue, + Dictionary pending_count, + object loop_state, + Tensor[] xs) + { + foreach (var x in _NonEagerInputs(op, xs)) + { + if (!pending_count.ContainsKey(x.op.name)) + pending_count[x.op.name] = 0; + + pending_count[x.op.name] -= 1; + + var ready = pending_count[x.op.name] == 0; + + if (loop_state != null && !ready) + { + + } + + if (ready) + { + if (control_flow_util.IsLoopExit(x.op)) + { + + } + else + { + queue.Enqueue(x.op); + } + } + } + } + + private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func grad_fn) + { + scope = scope.EndsWith("/") ? 
scope.Substring(0, scope.Length - 1) : scope; + return grad_fn(op, out_grads); + } + + private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op) + { + if (grads.Count() != op.inputs._inputs.Count()) + throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " + + $"inputs {op.inputs._inputs.Count()}"); + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index 3f4ab94d..f7f8e35f 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -168,6 +168,96 @@ namespace Tensorflow.Gradients return new Tensor[] { math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), null }; } + /// + /// Gradient for Max. + /// + /// + /// + /// + [RegisterGradient("Max")] + public static Tensor[] _MaxGrad(Operation op, Tensor[] grads) + { + return _MinOrMaxGrad(op, grads); + } + + /// + /// Gradient for Min. + /// + /// + /// + /// + [RegisterGradient("Min")] + public static Tensor[] _MinGrad(Operation op, Tensor[] grads) + { + return _MinOrMaxGrad(op, grads); + } + + private static Tensor[] _MinOrMaxGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var input_shape = array_ops.shape(op.inputs[0]); + var output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]); + var y = op.outputs[0]; + y = array_ops.reshape(y, output_shape_kept_dims); + grad = array_ops.reshape(grad, output_shape_kept_dims); + + // Compute the number of selected (maximum or minimum) elements in each + // reduction dimension. If there are multiple minimum or maximum elements + // then the gradient will be divided between them. + var indicators = math_ops.cast(math_ops.equal(y, op.inputs[0]), grad.dtype); + var num_selected = array_ops.reshape(math_ops.reduce_sum(indicators, op.inputs[1]), output_shape_kept_dims); + + return new Tensor[] { math_ops.div(indicators, num_selected) * grad, null }; + } + + /// + /// Returns grad*(x > y, x <= y) with type of grad. + /// + /// + /// + /// + [RegisterGradient("Maximum")] + public static Tensor[] _MaximumGrad(Operation op, Tensor[] grads) + { + return _MaximumMinimumGrad(op, grads[0]); + } + + /// + /// Returns grad*(x < y, x >= y) with type of grad. + /// + /// + /// + /// + [RegisterGradient("Minimum")] + public static Tensor[] _MinimumGrad(Operation op, Tensor[] grads) + { + return _MaximumMinimumGrad(op, grads[0]); + } + + /// + /// Factor out the code for the gradient of Maximum or Minimum. 
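// Worked example for the routing below (schematic values): for
// z = maximum(x, y) with x = [1, 5], y = [4, 2] and upstream grad g,
// xmask = (x >= y) = [false, true], so xgrad = [0, g[1]] and
// ygrad = [g[0], 0]; the reduce_sum/reshape pair then folds any broadcast
// dimensions back onto the original input shapes sx and sy.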
+ /// + /// + /// + /// + private static Tensor[] _MaximumMinimumGrad(Operation op, Tensor grad) + { + var x = op.inputs[0]; + var y = op.inputs[1]; + var gdtype = grad.dtype; + var sx = array_ops.shape(x); + var sy = array_ops.shape(y); + var gradshape = array_ops.shape(grad); + var zeros = array_ops.zeros(gradshape, gdtype); + var xmask = gen_math_ops.greater_equal(x, y); + var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); + var xgrad = array_ops.where(xmask, grad, zeros); + var ygrad = array_ops.where(xmask, zeros, grad); + var gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx); + var gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy); + return new Tensor[] { gx, gy }; + } + [RegisterGradient("Neg")] public static Tensor[] _NegGrad(Operation op, Tensor[] grads) { diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs index 19cdac36..7ef39cde 100644 --- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs @@ -106,10 +106,10 @@ namespace Tensorflow.Gradients [RegisterGradient("Conv2D")] public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads) { - var dilations = op.get_attr("dilations"); - var strides = op.get_attr("strides"); + var dilations = (op.get_attr("dilations") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(); + var strides = (op.get_attr("strides") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(); var padding = op.get_attr("padding"); - var explicit_paddings = op.get_attr("explicit_paddings"); + var explicit_paddings = (op.get_attr("explicit_paddings") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(); var use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu"); var data_format = op.get_attr("data_format"); var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] }); @@ -120,21 +120,23 @@ namespace Tensorflow.Gradients { InputSizes = shape[0], Filter = op.inputs[1], - Dilations = dilations == null ? null : dilations as int[], - Strides = strides == null ? null : strides as int[], + OutBackProp = grads[0], + Dilations = dilations, + Strides = strides, Padding = padding.ToString(), - ExplicitPaddings = explicit_paddings == null ? null : explicit_paddings as int[], + ExplicitPaddings = explicit_paddings, UseCudnnOnGpu = (bool)use_cudnn_on_gpu, - DataFormat = data_format.ToString() + DataFormat = data_format.ToString(), }), gen_nn_ops.conv2d_backprop_filter(new Conv2dParams { Input = op.inputs[0], FilterSizes = shape[1], - Dilations = dilations == null ? null : dilations as int[], - Strides = strides == null ? null : strides as int[], + OutBackProp = grads[0], + Dilations = dilations, + Strides = strides, Padding = padding.ToString(), - ExplicitPaddings = explicit_paddings == null ? 
null : explicit_paddings as int[], + ExplicitPaddings = explicit_paddings, UseCudnnOnGpu = (bool)use_cudnn_on_gpu, DataFormat = data_format.ToString() }) @@ -155,6 +157,23 @@ namespace Tensorflow.Gradients return vec * mat; } + [RegisterGradient("MaxPool")] + public static Tensor[] _MaxPoolGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + return new Tensor[] + { + gen_nn_ops.max_pool_grad( + op.inputs[0], + op.outputs[0], + grad, + (op.get_attr("ksize") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(), + (op.get_attr("strides") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(), + padding: op.get_attr("padding").ToString(), + data_format: op.get_attr("data_format").ToString()) + }; + } + /// /// Return the gradients for TopK. /// diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index 86d19bbf..8dba03b9 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -179,6 +179,23 @@ namespace Tensorflow.Operations return _op.outputs[0]; } + public static Tensor max_pool_grad(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding, + string data_format= "NHWC", string name= null) + { + var _op = _op_def_lib._apply_op_helper("MaxPoolGrad", name: name, args: new + { + orig_input, + orig_output, + grad, + ksize, + strides, + padding, + data_format + }); + + return _op.outputs[0]; + } + public static Tensor[] top_kv2(Tensor input, int k, bool sorted = true, string name = null) { var _op = _op_def_lib._apply_op_helper("TopKV2", name: name, args: new diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index bbb62be2..c5bd77f6 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -1,5 +1,7 @@ using Google.Protobuf.Collections; -//using Newtonsoft.Json; +#if GRAPH_SERIALIZE +using Newtonsoft.Json; +#endif using System; using System.Collections.Generic; using System.Linq; @@ -33,25 +35,34 @@ namespace Tensorflow private readonly IntPtr _operDesc; private Graph _graph; - //[JsonIgnore] + public string type => OpType; + +#if GRAPH_SERIALIZE + [JsonIgnore] + public Graph graph => _graph; + [JsonIgnore] + public int _id => _id_value; + [JsonIgnore] + public int _id_value; + [JsonIgnore] + public Operation op => this; +#else public Graph graph => _graph; - //[JsonIgnore] public int _id => _id_value; - //[JsonIgnore] public int _id_value; - - public string type => OpType; - //[JsonIgnore] public Operation op => this; +#endif public TF_DataType dtype => TF_DataType.DtInvalid; private Status status = new Status(); - public string name => c_api.StringPiece(c_api.TF_OperationName(_handle)); + public string name => _handle == IntPtr.Zero ? 
null : c_api.StringPiece(c_api.TF_OperationName(_handle)); public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); private NodeDef _node_def; - //[JsonIgnore] +#if GRAPH_SERIALIZE + [JsonIgnore] +#endif public NodeDef node_def { get diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index 0e043198..c997f179 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -36,6 +36,29 @@ namespace Tensorflow }); } + public static Tensor zeros(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) + { + dtype = dtype.as_base_dtype(); + return with(ops.name_scope(name, "zeros", shape), scope => + { + name = scope; + switch (dtype) + { + case TF_DataType.TF_BOOL: + return gen_array_ops.fill(shape, tf.constant(false, dtype: dtype), name: name); + case TF_DataType.TF_DOUBLE: + return gen_array_ops.fill(shape, tf.constant(0.0D, dtype: dtype), name: name); + case TF_DataType.TF_FLOAT: + return gen_array_ops.fill(shape, tf.constant(0.0F, dtype: dtype), name: name); + case TF_DataType.TF_INT32: + return gen_array_ops.fill(shape, tf.constant(0, dtype: dtype), name: name); + default: + throw new TypeError("can't find type for zeros"); + } + + }); + } + private static Tensor _constant_if_small(int value, Tensor shape) { return shape < 1000; @@ -127,8 +150,28 @@ namespace Tensorflow private static Tensor expand_dims_v2(Tensor input, int axis, string name = null) => gen_array_ops.expand_dims(input, axis, name); + /// + /// Returns the rank of a tensor. + /// + /// + /// + /// public static Tensor rank(Tensor input, string name = null) - => math_ops.rank_internal(input, name, optimize: true); + => rank_internal(input, name, optimize: true); + + public static Tensor rank_internal(Tensor input, string name = null, bool optimize = true) + { + return with(ops.name_scope(name, "Rank", new List { input }), scope => + { + name = scope; + var input_tensor = ops.convert_to_tensor(input); + var input_shape = tensor_util.to_shape(input_tensor.shape); + if (optimize && input_shape.NDim > 0) + return constant_op.constant(input_shape.NDim, dtype: tf.int32, name: name); + else + return gen_array_ops.rank(input, name); + }); + } /// /// Creates a tensor with all elements set to 1. 
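// Usage sketch for the zeros overload added above (values assumed):
//   var shape = tf.constant(new int[] { 2, 3 });
//   var z = array_ops.zeros(shape, TF_DataType.TF_FLOAT); // 2x3 of 0.0f
// Dtypes not covered by the switch fall through to the TypeError branch.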
@@ -233,6 +276,9 @@ namespace Tensorflow }); } + public static (Tensor, Tensor) unique(Tensor x, TF_DataType out_idx = TF_DataType.TF_INT32, string name = null) + => gen_array_ops.unique(x, out_idx: out_idx, name: name); + public static Tensor where(Tensor condition, object x = null, object y = null, string name = null) { if( x == null && y == null) @@ -277,7 +323,7 @@ namespace Tensorflow var input_shape = tensor_util.to_shape(input_tensor.shape); if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) { - var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype()); + var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_datatype()); return constant_op.constant(nd, name: name); } } diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index 259fee26..d543b380 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -123,7 +123,7 @@ namespace Tensorflow return with(ops.name_scope(name, "tuple", tensors), scope => { name = scope; - var gating_ops = tensors.Select(x => x.op).ToList(); + var gating_ops = tensors.Where(x => x != null).Select(x => x.op).ToList(); if(control_inputs != null) { @@ -139,7 +139,10 @@ namespace Tensorflow var tpl = new List(); foreach(var t in tensors) { - tpl.Add(with_dependencies(new Operation[] { gate }, t)); + if (t != null) + tpl.Add(with_dependencies(new Operation[] { gate }, t)); + else + tpl.Add(null); } return tpl.ToArray(); diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index fb980259..8308d48d 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -26,6 +26,13 @@ namespace Tensorflow return _op.outputs[0]; } + public static Tensor[] concat_offset(Tensor concat_dim, Tensor[] shape, string name = null) + { + var _op = _op_def_lib._apply_op_helper("ConcatOffset", name: name, args: new { concat_dim, shape }); + + return _op.outputs; + } + /// /// Returns a diagonal tensor with a given diagonal values. /// @@ -205,6 +212,21 @@ namespace Tensorflow return _op.outputs[0]; } + /// + /// Finds unique elements in a 1-D tensor. 
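// Behaviour sketch (the canonical Unique example): for
// x = [1, 1, 2, 4, 4, 4, 7, 8, 8],
//   var (y, idx) = gen_array_ops.unique(x);
// yields y = [1, 2, 4, 7, 8] and idx = [0, 0, 1, 2, 2, 2, 3, 4, 4],
// so x[i] == y[idx[i]] for every i.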
+ /// + /// + /// + /// + /// + public static (Tensor, Tensor) unique(Tensor x, TF_DataType out_idx = TF_DataType.TF_INT32, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Unique", name, new { x, out_idx }); + // TODO + //var _result = _UniqueOutput._make(_op.outputs); + return (_op.outputs[0], _op.outputs[1]); + } + public static Tensor where() { throw new NotImplementedException("where"); @@ -271,6 +293,26 @@ namespace Tensorflow return _op.outputs[0]; } + /// + /// Return a slice from 'input' + /// + /// + /// + /// + /// + /// + public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size }); + return _op.outputs[0]; + } + + public static Tensor[] split(Tensor axis, Tensor value, int num_split, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split }); + return _op.outputs; + } + public static Tensor tile(Tensor input, Tensor multiples, string name = null) { var _op = _op_def_lib._apply_op_helper("Tile", name, new { input, multiples }); diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index e5670dd0..763a4bd8 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -16,6 +16,19 @@ namespace Tensorflow return _op.outputs[0]; } + /// + /// Add all input tensors element wise. + /// + /// + /// + /// + public static Tensor add_n(Tensor[] inputs, string name = null) + { + var _op = _op_def_lib._apply_op_helper("AddN", name, args: new { inputs }); + + return _op.outputs[0]; + } + /// /// Returns the index with the largest value across dimensions of a tensor. /// @@ -198,6 +211,20 @@ namespace Tensorflow return _op.outputs[0]; } + /// + /// Computes the sum along segments of a tensor. + /// + /// + /// + /// + /// + /// + public static Tensor unsorted_segment_sum(Tensor data, Tensor segment_ids, Tensor num_segments, string name = null) + { + var _op = _op_def_lib._apply_op_helper("UnsortedSegmentSum", name, new { data, segment_ids, num_segments }); + return _op.outputs[0]; + } + public static Tensor tan(Tensor x, string name = null) { var _op = _op_def_lib._apply_op_helper("Tan", name, args: new { x }); diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index f37bd0dd..29e9d671 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -44,8 +44,8 @@ namespace Tensorflow return array_ops.identity(values, name: name); return values; } - throw new NotImplementedException("math_ops add_n n > 1"); - // return gen_math_ops.add_n(inputs, name: name); + + return gen_math_ops.add_n(inputs, name: name); } public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) @@ -65,6 +65,31 @@ namespace Tensorflow }); } + /// + /// Divide two values using Python 2 semantics. Used for Tensor.__div__. + /// + /// `Tensor` numerator of real numeric type. + /// `Tensor` denominator of real numeric type. + /// A name for the operation + /// `x / y` returns the quotient of x and y. 
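// Semantics sketch for the dispatch below (Python 2 division): integer
// dtypes take the floor_div branch, floating/complex dtypes the real_div
// branch, e.g. (values assumed)
//   math_ops.div(tf.constant(7), tf.constant(2));     // -> 3   (floored)
//   math_ops.div(tf.constant(7f), tf.constant(2f));   // -> 3.5 (true division)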
+ public static Tensor div(Tensor x, Tensor y, string name = null) + { + return with(ops.name_scope(name, "div", (x, y)), name_scope => + { + name = name_scope; + x = ops.convert_to_tensor(x, name: "x"); + y = ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name = "y"); + var x_dtype = x.dtype.as_base_dtype(); + var y_dtype = y.dtype.as_base_dtype(); + if (x_dtype != y_dtype) + throw new TypeError($"x and y must have the same dtype, got {x_dtype} != {y_dtype}"); + if (x_dtype.is_floating() || x_dtype.is_complex()) + return gen_math_ops.real_div(x, y, name: name); + else + return gen_math_ops.floor_div(x, y, name: name); + }); + } + /// /// Returns 0 if the denominator is zero. /// @@ -101,6 +126,9 @@ namespace Tensorflow public static Tensor equal(Tx x, Ty y, string name = null) => gen_math_ops.equal(x, y, name: name); + public static Tensor sqrt(Tensor x, string name = null) + => gen_math_ops.sqrt(x, name: name); + public static Tensor multiply(Tx x, Ty y, string name = null) => gen_math_ops.mul(x, y, name: name); @@ -294,6 +322,17 @@ namespace Tensorflow return _may_reduce_to_scalar(keepdims, axis, min); } + /// + /// Computes the sum along segments of a tensor. + /// + /// + /// + /// + /// + /// + public static Tensor unsorted_segment_sum(Tensor data, Tensor segment_ids, Tensor num_segments, string name = null) + => gen_math_ops.unsorted_segment_sum(data, segment_ids, num_segments, name: name); + /// /// Casts a tensor to type `int32`. /// @@ -429,20 +468,6 @@ namespace Tensorflow }); } - public static Tensor rank_internal(Tensor input, string name = null, bool optimize = true) - { - return with(ops.name_scope(name, "Rank", new List { input }), scope => - { - name = scope; - var input_tensor = ops.convert_to_tensor(input); - var input_shape = tensor_util.to_shape(input_tensor.shape); - if (optimize && input_shape.NDim == null) - return constant_op.constant(input_shape.NDim); - else - return gen_array_ops.rank(input, name); - }); - } - public static Tensor maximum(Tx x, Ty y, string name = null) => gen_math_ops.maximum(x, y, name: name); diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs index f8ecf7b9..eb48c5bc 100644 --- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs @@ -16,5 +16,10 @@ namespace Tensorflow value_tensor, name: name); } + + public static bool is_resource_variable(VariableV1 var) + { + return var is ResourceVariable; + } } } diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 71ae49c0..1e8a63f1 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -5,10 +5,10 @@ TensorFlow.NET Tensorflow 1.14.0 - 0.8.1 + 0.8.2 Haiping Chen SciSharp STACK - true + false Apache 2.0 https://github.com/SciSharp/TensorFlow.NET git @@ -17,19 +17,20 @@ TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C# Google's TensorFlow full binding in .NET Standard. Docs: https://tensorflownet.readthedocs.io - 0.8.1.0 + 0.8.2.0 Changes since v0.8: 1. Remove global static graph instance. 2. Provide custom gradient function. -3. Add gradient function for Conv2D. +3. Add gradient function for Conv2D. +4. Fix bug for Transfer Learning example. 
7.2 - 0.8.1.0 + 0.8.2.0 true - DEBUG;TRACE + TRACE;DEBUG @@ -43,18 +44,17 @@ Docs: https://tensorflownet.readthedocs.io - + - - + - - - + + + diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 5e3d611b..f782451f 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -1,4 +1,6 @@ -//using Newtonsoft.Json; +#if GRAPH_SERIALIZE +using Newtonsoft.Json; +#endif using NumSharp; using System; using System.Collections.Generic; @@ -19,15 +21,22 @@ namespace Tensorflow private readonly IntPtr _handle; private int _id; - //[JsonIgnore] + private Operation _op; +#if GRAPH_SERIALIZE + [JsonIgnore] + public int Id => _id; + [JsonIgnore] + public Graph graph => op?.graph; + [JsonIgnore] + public Operation op => _op; + [JsonIgnore] + public Tensor[] outputs => op.outputs; +#else public int Id => _id; - //[JsonIgnore] public Graph graph => op?.graph; - private Operation _op; - //[JsonIgnore] public Operation op => _op; - //[JsonIgnore] public Tensor[] outputs => op.outputs; +#endif /// /// The string name of this tensor. @@ -49,6 +58,11 @@ namespace Tensorflow private TF_Output? _tf_output; + /// + /// used for keep other pointer when do implicit operating + /// + public object Tag { get; set; } + public int[] shape { get @@ -210,11 +224,11 @@ namespace Tensorflow } } - public Tensor this[int slice_spec] + public Tensor this[Slice slice] { get { - var slice_spec_s = new int[] { slice_spec }; + var slice_spec = new int[] { slice.Start.Value }; var begin = new List(); var end = new List(); var strides = new List(); @@ -224,22 +238,27 @@ namespace Tensorflow var (begin_mask, end_mask) = (0, 0); var ellipsis_mask = 0; - foreach(var s in slice_spec_s) + foreach (var s in slice_spec) { + begin.Add(s); + if(slice.Stop.HasValue) + { + end.Add(slice.Stop.Value); + } + else { - begin.Add(s); - end.Add(s + 1); - strides.Add(1); - shrink_axis_mask |= (1 << index); + end.Add(0); + end_mask |= (1 << index); } - + strides.Add(slice.Step); + index += 1; } return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => { string name = scope; - if(begin != null) + if (begin != null) { var (packed_begin, packed_end, packed_strides) = (array_ops.stack(begin.ToArray()), @@ -256,13 +275,65 @@ namespace Tensorflow shrink_axis_mask: shrink_axis_mask, new_axis_mask: new_axis_mask, ellipsis_mask: ellipsis_mask, + + name: name); + } + + throw new NotImplementedException(""); + }); + } + } + + public Tensor this[int start] + { + get + { + var slice_spec = new int[] { start }; + var begin = new List(); + var end = new List(); + var strides = new List(); + + var index = 0; + var (new_axis_mask, shrink_axis_mask) = (0, 0); + var (begin_mask, end_mask) = (0, 0); + var ellipsis_mask = 0; + + foreach (var s in slice_spec) + { + begin.Add(s); + end.Add(s + 1); + strides.Add(1); + shrink_axis_mask |= (1 << index); + index += 1; + } + + return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => + { + string name = scope; + if (begin != null) + { + var (packed_begin, packed_end, packed_strides) = + (array_ops.stack(begin.ToArray()), + array_ops.stack(end.ToArray()), + array_ops.stack(strides.ToArray())); + + return gen_array_ops.strided_slice( + this, + packed_begin, + packed_end, + packed_strides, + begin_mask: begin_mask, + end_mask: end_mask, + shrink_axis_mask: shrink_axis_mask, + new_axis_mask: new_axis_mask, + ellipsis_mask: ellipsis_mask, + name: name); } throw new 
NotImplementedException(""); }); } - } public override string ToString() diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index 444f384d..5b2fedc5 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -16,6 +16,8 @@ namespace Tensorflow { case TF_DataType.TF_BOOL: return typeof(bool); + case TF_DataType.TF_INT64: + return typeof(long); case TF_DataType.TF_INT32: return typeof(int); case TF_DataType.TF_INT16: diff --git a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs index b6063234..06f51352 100644 --- a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs +++ b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs @@ -1,6 +1,9 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; +using Tensorflow.Framework; +using static Tensorflow.Python; namespace Tensorflow.Train { @@ -10,9 +13,10 @@ namespace Tensorflow.Train /// public class AdamOptimizer : Optimizer { - private float _beta1; - private float _beta2; - private float _epsilon; + float _beta1; + float _beta2; + float _epsilon; + Tensor _lr_t, _beta1_t, _beta2_t, _epsilon_t; public AdamOptimizer(float learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, string name = "Adam") : base(learning_rate, use_locking, name) @@ -21,5 +25,79 @@ namespace Tensorflow.Train _beta2 = beta2; _epsilon = epsilon; } + + public override Operation _apply_sparse(IndexedSlices grad, RefVariable var) + { + return _apply_sparse_shared(grad.values, var, grad.indices, (x, i, v) => + { + return state_ops.scatter_add(x, i, v, use_locking: _use_locking); + }); + } + + private Operation _apply_sparse_shared(Tensor grad, RefVariable var, Tensor indices, Func scatter_add) + { + var (beta1_power_v, beta2_power_v) = _get_beta_accumulators(); + Tensor beta1_power = math_ops.cast(beta1_power_v, var.dtype.as_base_dtype()); + Tensor beta2_power = math_ops.cast(beta2_power_v, var.dtype.as_base_dtype()); + var lr_t = math_ops.cast(_lr_t, var.dtype.as_base_dtype()); + var beta1_t = math_ops.cast(_beta1_t, var.dtype.as_base_dtype()); + var beta2_t = math_ops.cast(_beta2_t, var.dtype.as_base_dtype()); + var epsilon_t = math_ops.cast(_epsilon_t, var.dtype.as_base_dtype()); + var lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)); + var m = get_slot(var, "m"); + var m_scaled_g_values = grad * (1 - beta1_t); + var mul = m * beta1_t; + var m_t = state_ops.assign(m, mul, use_locking: _use_locking); + with(ops.control_dependencies(new[] { m_t }), delegate + { + m_t = scatter_add(m, indices, m_scaled_g_values); + }); + + var v = get_slot(var, "v"); + var v_scaled_g_values = (grad * grad) * (1 - beta2_t); + var v_t = state_ops.assign(v, v * beta2_t, use_locking: _use_locking); + with(ops.control_dependencies(new[] { v_t }), delegate + { + v_t = scatter_add(v, indices, v_scaled_g_values); + }); + var v_sqrt = math_ops.sqrt(v_t); + var var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking: _use_locking); + return control_flow_ops.group(new[] { var_update, m_t, v_t }); + } + + protected override void _create_slots(RefVariable[] var_list) + { + var first_var = var_list.OrderBy(x => x.name).First(); + _create_non_slot_variable(initial_value: _beta1, name: "beta1_power", colocate_with: first_var); + _create_non_slot_variable(initial_value: _beta2, name: "beta2_power", colocate_with: first_var); + + // Create slots for the first and 
second moments. + foreach(var v in var_list) + { + _zeros_slot(v, "m", Name); + _zeros_slot(v, "v", Name); + } + } + + private (RefVariable, RefVariable) _get_beta_accumulators() + { + ops.init_scope(); + var graph = ops.get_default_graph(); + return (_get_non_slot_variable("beta1_power", graph: graph), + _get_non_slot_variable("beta2_power", graph: graph)); + } + + public override void _prepare() + { + var lr = _call_if_callable(_lr); + var beta1 = _call_if_callable(_beta1); + var beta2 = _call_if_callable(_beta2); + var epsilon = _call_if_callable(_epsilon); + + _lr_t = ops.convert_to_tensor(lr, name: "learning_rate"); + _beta1_t = ops.convert_to_tensor(beta1, name: "beta1"); + _beta2_t = ops.convert_to_tensor(beta2, name: "beta2"); + _epsilon_t = ops.convert_to_tensor(epsilon, name: "epsilon"); + } } } diff --git a/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs b/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs index f545d859..d69228d6 100644 --- a/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs +++ b/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs @@ -26,14 +26,13 @@ namespace Tensorflow.Train public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent") : base(learning_rate, use_locking, name) { - LearningRate = learning_rate; - LearningRateTensor = null; + _lr = learning_rate; } public override void _prepare() { - LearningRate = _call_if_callable(LearningRate); - LearningRateTensor = ops.convert_to_tensor(LearningRate, name: "learning_rate"); + var lr = _call_if_callable(_lr); + _lr_t = ops.convert_to_tensor(lr, name: "learning_rate"); } } } diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs index 3a14390d..c7a31b9d 100644 --- a/src/TensorFlowNET.Core/Train/Optimizer.cs +++ b/src/TensorFlowNET.Core/Train/Optimizer.cs @@ -2,6 +2,8 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow.Framework; +using Tensorflow.Train; using static Tensorflow.Python; namespace Tensorflow @@ -12,32 +14,36 @@ namespace Tensorflow /// class directly, but instead instantiate one of its subclasses such as /// `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`. /// - public abstract class Optimizer + public abstract class Optimizer : Trackable { // Values for gate_gradients. public static int GATE_NONE = 0; public static int GATE_OP = 1; public static int GATE_GRAPH = 2; - public string Name { get; set; } - public float LearningRate { get; set; } - public Tensor LearningRateTensor { get; set; } + string _name; + public string Name => _name; + protected float _lr; + public float LearningRate => _lr; + protected Tensor _lr_t; + public Tensor LearningRateTensor => _lr_t; public bool _use_locking; - public Dictionary _slots; - public Dictionary _non_slot_dict; + public Dictionary> _slots; + public Dictionary _non_slot_dict; public Dictionary _deferred_slot_restorations; + SlotCreator slot_creator = new SlotCreator(); public Optimizer(float learning_rate, bool use_locking, string name = null) { if (String.IsNullOrEmpty(name)) throw new NotImplementedException("Must specify the optimizer name"); - Name = name; + _name = name; _use_locking = use_locking; - LearningRate = learning_rate; + _lr = learning_rate; // Dictionary of slots. 
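// Keyed first by slot name (e.g. "m" and "v" for Adam), then by
// _var_key(var), i.e. "{graph_key}.{op_name}" (see _var_key below).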
- _slots = new Dictionary(); - _non_slot_dict = new Dictionary(); + _slots = new Dictionary>(); + _non_slot_dict = new Dictionary(); _deferred_slot_restorations = new Dictionary(); } @@ -110,7 +116,7 @@ namespace Tensorflow public Operation apply_gradients(Tuple[] grads_and_vars, RefVariable global_step = null, string name = null) { // No DistributionStrategy case. - var converted_grads_and_vars = new List>(); + var converted_grads_and_vars = new List<(Tensor, RefVariable, _OptimizableVariable)>(); foreach (var (g, v) in grads_and_vars) { if(g != null) @@ -118,7 +124,7 @@ namespace Tensorflow // Convert the grad to Tensor or IndexedSlices if necessary. var gR = ops.convert_to_tensor_or_indexed_slices(g); var p = _get_processor(v); - converted_grads_and_vars.Add(new Tuple(gR, v, p)); + converted_grads_and_vars.Add((gR, v, p)); } } @@ -143,7 +149,8 @@ namespace Tensorflow var scope_name = var.op.name; with(ops.name_scope("update_" + scope_name), scope2 => { - update_ops.Add(processor.update_op(this, grad)); + var op = processor.update_op(this, grad); + update_ops.Add(op); }); } @@ -185,9 +192,49 @@ namespace Tensorflow }); } - private void _create_slots(RefVariable[] var_list) + /// + /// Create the beta1 and beta2 accumulators on the same device as the first + /// variable. Sort the var_list to make sure this device is consistent across + /// workers (these need to go on the same PS, otherwise some updates are + /// silently ignored). + /// + /// + protected virtual void _create_slots(RefVariable[] var_list) { + + } + /// + /// Add an extra variable, not associated with a slot. + /// + /// + /// + /// + protected RefVariable _create_non_slot_variable(float initial_value, string name, RefVariable colocate_with) + { + // Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables. + var graph = colocate_with.graph; + var key = $"{name}.{graph.graph_key}"; + var v = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null; + if(v == null) + { + _maybe_initialize_trackable(); + v = variable_scope.default_variable_creator( + initial_value, name: name, trainable: false, + use_resource: resource_variable_ops.is_resource_variable( + colocate_with)); + + // Restore this variable by name if necessary, but don't add a + // Trackable dependency. Optimizers return the current graph's + // non-slot variables from _checkpoint_dependencies explicitly rather + // than unconditionally adding dependencies (since there may be multiple + // non-slot variables with the same name in different graphs, trying to + // save all of them would result in errors). + _handle_deferred_dependencies(name, v); + _non_slot_dict[key] = v; + } + + return v; } public virtual Operation _finish(Operation[] update_ops, string name_scope) @@ -201,11 +248,68 @@ namespace Tensorflow return gen_training_ops.apply_gradient_descent(var, alpha, grad, use_locking: _use_locking).op; } + /// + /// Add ops to apply sparse gradients to `var`, with repeated sparse indices. 
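+ /// Repeated indices are first summed into unique slices by _deduplicate_indexed_slices, then handed to _apply_sparse.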
+ /// + /// + /// + /// + public virtual Operation _apply_sparse_duplicate_indices(IndexedSlices grad, RefVariable var) + { + var (summed_values, unique_indices) = _deduplicate_indexed_slices(values: grad.values, indices: grad.indices); + var gradient_no_duplicate_indices = new IndexedSlices( + indices: unique_indices, + values: summed_values, + dense_shape: grad.dense_shape); + return _apply_sparse(gradient_no_duplicate_indices, var); + } + + public virtual Operation _apply_sparse(IndexedSlices grad, RefVariable var) + { + throw new NotImplementedException("_apply_sparse"); + } + + public virtual (Tensor, Tensor) _deduplicate_indexed_slices(Tensor values, Tensor indices) + { + var (unique_indices, new_index_positions) = array_ops.unique(indices); + var shape = array_ops.shape(unique_indices)[0]; + var summed_values = math_ops.unsorted_segment_sum(values, new_index_positions, shape); + return (summed_values, unique_indices); + } + public virtual void _prepare() { } + /// + /// Return a slot named `name` created for `var` by the Optimizer. + /// + /// + /// + /// + protected RefVariable get_slot(RefVariable var, string name) + { + var named_slots = _slots.ContainsKey(name) ? _slots[name] : null; + if (named_slots == null) + return null; + + return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null; + } + + private string _var_key(RefVariable var) + { + return $"{var.op.graph.graph_key}.{var.op.name}"; + } + + protected RefVariable _get_non_slot_variable(string name, Graph graph = null) + { + var key = $"{name}.{graph.graph_key}"; + var non_slot = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null; + + return non_slot; + } + private _OptimizableVariable _get_processor(RefVariable v) { if(v is RefVariable) @@ -282,5 +386,45 @@ namespace Tensorflow { return param; } + + /// + /// Find or create a slot initialized with 0.0. + /// + /// + /// + /// + /// + protected RefVariable _zeros_slot(RefVariable var, string slot_name, string op_name) + { + var named_slots = _slot_dict(slot_name); + if (!named_slots.ContainsKey(_var_key(var))) + { + var new_slot_variable = slot_creator.create_zeros_slot(var, op_name); + _restore_slot_variable(slot_name: slot_name, variable: var, slot_variable: new_slot_variable); + named_slots[_var_key(var)] = new_slot_variable; + } + return named_slots[_var_key(var)]; + } + + /// + /// Restore a newly created slot variable's value. + /// + protected void _restore_slot_variable(string slot_name, RefVariable variable, RefVariable slot_variable) + { + var variable_key = _var_key(variable); + // TODO + } + + protected Dictionary _slot_dict(string slot_name) + { + var named_slots = _slots.ContainsKey(slot_name) ? _slots[slot_name] : null; + if(named_slots == null) + { + named_slots = new Dictionary(); + _slots[slot_name] = named_slots; + } + + return named_slots; + } } } diff --git a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs index f0db80ce..13886401 100644 --- a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs +++ b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs @@ -16,6 +16,12 @@ namespace Tensorflow _write_version = write_version; } + /// + /// Create an Op to save 'saveables'. 
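+ /// Gathers the name and tensor of every saveable so a single op can write them all to filename_tensor.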
+ /// + /// + /// + /// public virtual Operation save_op(Tensor filename_tensor, SaveableObject[] saveables) { var tensor_names = new List(); @@ -105,6 +111,10 @@ namespace Tensorflow } var graph = ops.get_default_graph(); + // Do some sanity checking on collections containing + // PartitionedVariables. If a saved collection has a PartitionedVariable, + // the GraphDef needs to include concat ops to get the value (or there'll + // be a lookup error on load). var check_collection_list = graph.get_all_collection_keys(); foreach (var collection_type in check_collection_list) { diff --git a/src/TensorFlowNET.Core/Train/Saving/Saver.cs b/src/TensorFlowNET.Core/Train/Saving/Saver.cs index 0f3a2ab8..b367eb1d 100644 --- a/src/TensorFlowNET.Core/Train/Saving/Saver.cs +++ b/src/TensorFlowNET.Core/Train/Saving/Saver.cs @@ -158,7 +158,10 @@ namespace Tensorflow string model_checkpoint_path = ""; string checkpoint_file = ""; - checkpoint_file = $"{save_path}-{global_step}"; + if (global_step > 0) + checkpoint_file = $"{save_path}-{global_step}"; + else + checkpoint_file = save_path; var save_path_parent = Path.GetDirectoryName(save_path); @@ -291,15 +294,13 @@ namespace Tensorflow if (_saver_def.MaxToKeep <= 0) return; // Remove first from list if the same name was used before. - foreach (var p in _last_checkpoints) - if (latest_save_path == _CheckpointFilename((p.Key, p.Value))) - _last_checkpoints.Remove(p.Key); - - // Append new path to list - _last_checkpoints.Add(latest_save_path, Python.time()); + var _existed_checkpoints = _last_checkpoints.FirstOrDefault(p => latest_save_path == _CheckpointFilename((p.Key, p.Value))); + if (_existed_checkpoints.Key != null) + _last_checkpoints.Remove(_existed_checkpoints.Key); + _last_checkpoints.Add(latest_save_path, time()); // If more than max_to_keep, remove oldest. - if(_last_checkpoints.Count > _saver_def.MaxToKeep) + if (_last_checkpoints.Count > _saver_def.MaxToKeep) { var first = _last_checkpoints.First(); _last_checkpoints.Remove(first.Key); diff --git a/src/TensorFlowNET.Core/Train/Saving/saver.py.cs b/src/TensorFlowNET.Core/Train/Saving/saver.py.cs index 303f41a4..c21b954f 100644 --- a/src/TensorFlowNET.Core/Train/Saving/saver.py.cs +++ b/src/TensorFlowNET.Core/Train/Saving/saver.py.cs @@ -25,7 +25,7 @@ namespace Tensorflow var saver = _create_saver_from_imported_meta_graph( meta_graph_def, import_scope, imported_vars); - return (saver, null); + return (saver, imported_return_elements); } /// diff --git a/src/TensorFlowNET.Core/Train/SlotCreator.cs b/src/TensorFlowNET.Core/Train/SlotCreator.cs new file mode 100644 index 00000000..a666cae0 --- /dev/null +++ b/src/TensorFlowNET.Core/Train/SlotCreator.cs @@ -0,0 +1,81 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Operations.Initializers; +using static Tensorflow.Python; + +namespace Tensorflow.Train +{ + public class SlotCreator + { + /// + /// Create a slot initialized to 0 with same shape as the primary object. 
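+ /// Requires the primary shape to be fully defined; delegates to create_slot_with_initializer with a Zeros initializer.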
+ /// + /// + /// + /// + /// + /// + public RefVariable create_zeros_slot(RefVariable primary, string name, TF_DataType dtype = TF_DataType.DtInvalid, bool colocate_with_primary = true) + { + if (dtype == TF_DataType.DtInvalid) + dtype = primary.dtype; + var slot_shape = primary.shape; + if (slot_shape.is_fully_defined()) + { + var initializer = new Zeros(); + return create_slot_with_initializer( + primary, initializer, slot_shape, dtype, name, + colocate_with_primary: colocate_with_primary); + } + else + { + throw new NotImplementedException("create_zeros_slot is not fully defined."); + } + } + + /// + /// Creates a slot initialized using an `Initializer`. + /// + /// + public RefVariable create_slot_with_initializer(RefVariable primary, IInitializer initializer, TensorShape shape, + TF_DataType dtype, string name, bool colocate_with_primary = true) + { + var validate_shape = shape.is_fully_defined(); + var prefix = primary.op.name; + return with(new variable_scope(string.Empty, prefix + "/" + name), delegate + { + return _create_slot_var(primary, initializer, "", validate_shape, shape, dtype); + }); + } + + /// + /// Helper function for creating a slot variable. + /// + /// + /// + /// + /// + /// + /// + /// + private RefVariable _create_slot_var(VariableV1 primary, IInitializer val, string scope, bool validate_shape, + TensorShape shape, TF_DataType dtype) + { + bool use_resource = primary is ResourceVariable; + if (resource_variable_ops.is_resource_variable(primary)) + use_resource = true; + + var slot = tf.get_variable( + scope, + initializer: val, + trainable: false, + use_resource: use_resource, + shape: shape, + dtype: dtype, + validate_shape: validate_shape); + + return slot; + } + } +} diff --git a/src/TensorFlowNET.Core/Train/Trackable.cs b/src/TensorFlowNET.Core/Train/Trackable.cs index c16304a9..c98b2116 100644 --- a/src/TensorFlowNET.Core/Train/Trackable.cs +++ b/src/TensorFlowNET.Core/Train/Trackable.cs @@ -6,6 +6,8 @@ namespace Tensorflow.Train { public abstract class Trackable { + protected int _self_update_uid; + /// /// Restore-on-create for a variable be saved with this `Checkpointable`. /// @@ -32,9 +34,29 @@ namespace Tensorflow.Train return new_variable; } + /// + /// Pop and load any deferred checkpoint restores into `trackable`. + /// + /// + /// + protected void _handle_deferred_dependencies(string name, RefVariable trackable) + { + _maybe_initialize_trackable(); + // TODO + } + protected RefVariable _track_checkpointable(RefVariable checkpointable, string name, bool overwrite = false) { return checkpointable; } + + /// + /// Initialize dependency management. 
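+ /// Resets _self_update_uid to -1 so the trackable is treated as never having been updated from a checkpoint.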
+ /// + protected void _maybe_initialize_trackable() + { + // _self_unconditional_checkpoint_dependencies = [] + _self_update_uid = -1; + } } } diff --git a/src/TensorFlowNET.Core/Train/_OptimizableVariable.cs b/src/TensorFlowNET.Core/Train/_OptimizableVariable.cs index e363e580..2d61781a 100644 --- a/src/TensorFlowNET.Core/Train/_OptimizableVariable.cs +++ b/src/TensorFlowNET.Core/Train/_OptimizableVariable.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Framework; namespace Tensorflow { diff --git a/src/TensorFlowNET.Core/Train/optimizer.py.cs b/src/TensorFlowNET.Core/Train/optimizer.py.cs index 3a376e97..15c302b4 100644 --- a/src/TensorFlowNET.Core/Train/optimizer.py.cs +++ b/src/TensorFlowNET.Core/Train/optimizer.py.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Framework; namespace Tensorflow { @@ -28,7 +29,16 @@ namespace Tensorflow public Operation update_op(Optimizer optimizer, Tensor g) { - var update_op = optimizer._apply_dense(g, _v); + Operation update_op = null; + + if (g.Tag == null) + { + update_op = optimizer._apply_dense(g, _v); + } + else if (g.Tag is IndexedSlices) + { + return optimizer._apply_sparse_duplicate_indices(g, _v); + } return update_op; } diff --git a/src/TensorFlowNET.Core/Variables/VariableScope.cs b/src/TensorFlowNET.Core/Variables/VariableScope.cs index 60cd7777..d9816af3 100644 --- a/src/TensorFlowNET.Core/Variables/VariableScope.cs +++ b/src/TensorFlowNET.Core/Variables/VariableScope.cs @@ -37,6 +37,8 @@ namespace Tensorflow TF_DataType dtype = TF_DataType.DtInvalid, object initializer = null, // IInitializer or Tensor bool? trainable = null, + bool? use_resource = null, + bool validate_shape = true, VariableSynchronization synchronization = VariableSynchronization.Auto, VariableAggregation aggregation= VariableAggregation.None) { diff --git a/src/TensorFlowNET.Core/Variables/_VariableStore.cs b/src/TensorFlowNET.Core/Variables/_VariableStore.cs index 97c6b912..420f929f 100644 --- a/src/TensorFlowNET.Core/Variables/_VariableStore.cs +++ b/src/TensorFlowNET.Core/Variables/_VariableStore.cs @@ -57,24 +57,24 @@ namespace Tensorflow if (initializer is IInitializer init) { return _get_single_variable(name: name, - shape: shape, - dtype: dtype, - initializer: init, - trainable: trainable, - validate_shape: validate_shape, - synchronization: synchronization, - aggregation: aggregation); + shape: shape, + dtype: dtype, + initializer: init, + trainable: trainable, + validate_shape: validate_shape, + synchronization: synchronization, + aggregation: aggregation); } else if (initializer is Tensor tensor) { return _get_single_variable(name: name, - shape: shape, - dtype: dtype, - initializer: tensor, - trainable: trainable, - validate_shape: validate_shape, - synchronization: synchronization, - aggregation: aggregation); + shape: shape, + dtype: dtype, + initializer: tensor, + trainable: trainable, + validate_shape: validate_shape, + synchronization: synchronization, + aggregation: aggregation); } else { @@ -141,7 +141,7 @@ namespace Tensorflow v = variable_scope.default_variable_creator(init_val, name: name, trainable: trainable, - dtype: TF_DataType.DtInvalid, + dtype: variable_dtype, validate_shape: validate_shape, synchronization: synchronization, aggregation: aggregation); diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs index 4b4237a0..a5a4ab69 100644 --- 
a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs +++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs @@ -97,6 +97,20 @@ namespace Tensorflow var _op = _op_def_lib._apply_op_helper("AssignAdd", name: name, args: new { @ref, value, use_locking }); return _op.outputs[0]; } - + + /// + /// Adds sparse updates to a variable reference. + /// + /// + /// + /// + /// + /// + /// + public static Tensor scatter_add(RefVariable @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null) + { + var _op = _op_def_lib._apply_op_helper("ScatterAdd", name: name, args: new { @ref, indices, updates, use_locking }); + return _op.outputs[0]; + } } } diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs index aaa27e85..22894fe0 100644 --- a/src/TensorFlowNET.Core/Variables/state_ops.cs +++ b/src/TensorFlowNET.Core/Variables/state_ops.cs @@ -36,8 +36,8 @@ namespace Tensorflow validate_shape: validate_shape, use_locking: use_locking, name: name); - else - throw new NotImplementedException("state_ops.assign"); + throw new NotImplementedException("state_ops.assign"); + //return @ref.assign(value, name: name); } public static Tensor assign_sub(RefVariable @ref, @@ -72,5 +72,13 @@ namespace Tensorflow Tensor value, bool use_locking = false, string name = null) => gen_state_ops.assign_add(@ref, value, use_locking: use_locking, name: name); + + public static Tensor scatter_add(RefVariable @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null) + { + if (@ref.dtype.is_ref_dtype()) + return gen_state_ops.scatter_add(@ref, indices, updates, use_locking: use_locking, name: name); + + throw new NotImplementedException("scatter_add"); + } } } diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs index c972ae99..b84196a3 100644 --- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -104,7 +104,7 @@ namespace Tensorflow current_name_scope = ops.name_scope(name_scope); } - if (_name != null || _scope != null) + if (!string.IsNullOrEmpty(_name) || _scope != null) { var name_scope = _scope.name.Split('/').Last(); if (current_name_scope == null) @@ -270,7 +270,11 @@ namespace Tensorflow } // TODO for Switch/Case - public static RefVariable get_variable(string embeddingMatrix, double[,] initializer, bool use_resource) + public static RefVariable get_variable(string embeddingMatrix, IInitializer initializer, bool use_resource, + TensorShape shape = null, + TF_DataType dtype = TF_DataType.DtInvalid, + bool trainable = false, + bool validate_shape = true) { throw new NotImplementedException(); } diff --git a/tensorflowlib/README.md b/tensorflowlib/README.md index 21b14b72..77a78a66 100644 --- a/tensorflowlib/README.md +++ b/tensorflowlib/README.md @@ -1,14 +1,20 @@ -TensorFlow.NET pack all required libraries in architecture-specific assemblies folders per NuGet standard. +TensorFlow.NET packs all required libraries into architecture-specific assembly folders per the NuGet standard. [Deprecated] + +We now use `Microsoft.ML.TensorFlow.Redist` to manage the native TensorFlow library.
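+
+If your own project consumes TensorFlow.NET, the native binaries can be pulled in through that package. A minimal sketch of the package reference (the version number below is illustrative; check NuGet for the release matching your TensorFlow.NET version):
+
+```xml
+<PackageReference Include="Microsoft.ML.TensorFlow.Redist" Version="0.14.0" />
+```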
+ + + +### Download manually Here are some pre-built TensorFlow binaries you can use for each platform: - Linux - - CPU-only: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.13.1.tar.gz - - GPU-enabled: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-1.13.1.tar.gz -- Mac: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-1.13.1.tar.gz + - CPU-only: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.14.0.tar.gz + - GPU-enabled: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-1.14.0.tar.gz +- Mac: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-1.14.0.tar.gz - Windows - - CPU-only: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-1.13.1.zip - - GPU-enabled: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-1.13.1.zip + - CPU-only: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-1.14.0.zip + - GPU-enabled: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-1.14.0.zip ### Run in Linux @@ -16,6 +22,15 @@ Here are some pre-built TensorFlow binaries you can use for each platform: Download Linux pre-built library and unzip `libtensorflow.so` and `libtensorflow_framework.so` into current running directory. +To run the image recognition examples on Linux, make sure the prerequisite libraries are installed: + +```shell +sudo apt install libc6-dev +sudo apt install libgdiplus +``` + +For more information, see [System.Drawing on Linux](). + ### Run in Mac OS ### GPU Tensorflow for windows @@ -41,7 +56,7 @@ pacman -S git patch unzip 4. Install from local wheel file. -`pip install C:/tmp/tensorflow_pkg/tensorflow-1.13.0-cp36-cp36m-win_amd64.whl` +`pip install C:/tmp/tensorflow_pkg/tensorflow-1.14.0-cp36-cp36m-win_amd64.whl` ### Export more APIs diff --git a/tensorflowlib/runtimes/win-x64/native/tensorflow.dll b/tensorflowlib/runtimes/win-x64/native/tensorflow.dll index d4c2474c..e69de29b 100644 Binary files a/tensorflowlib/runtimes/win-x64/native/tensorflow.dll and b/tensorflowlib/runtimes/win-x64/native/tensorflow.dll differ diff --git a/test/KerasNET.Test/Keras.UnitTest.csproj b/test/KerasNET.Test/Keras.UnitTest.csproj index 89a5425c..1e9a2253 100644 --- a/test/KerasNET.Test/Keras.UnitTest.csproj +++ b/test/KerasNET.Test/Keras.UnitTest.csproj @@ -26,7 +26,7 @@ - + diff --git a/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs index aaaf6865..3793027a 100644 --- a/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs +++ b/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs @@ -105,6 +105,8 @@ namespace TensorFlowNET.Examples.ImageProcess // Create a train saver that is used to restore values into an eval graph // when exporting models.
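// An initial checkpoint is written right after the saver is created (below) so
// CHECKPOINT_NAME exists before the first export, even if training stops early.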
var train_saver = tf.train.Saver(); + train_saver.save(sess, CHECKPOINT_NAME); + sw.Restart(); for (int i = 0; i < how_many_training_steps; i++) @@ -178,6 +180,7 @@ namespace TensorFlowNET.Examples.ImageProcess print($"Save final result to : {output_graph}"); save_graph_to_file(output_graph, class_count); File.WriteAllText(output_labels, string.Join("\n", image_lists.Keys)); + return test_accuracy > 0.75f; }); } @@ -604,7 +607,7 @@ namespace TensorFlowNET.Examples.ImageProcess // download variables.data checkpoint file. url = "https://github.com/SciSharp/TensorFlow.NET/raw/master/data/tfhub_modules.zip"; Web.Download(url, data_dir, "tfhub_modules.zip"); - Compress.UnZip(Path.Join(data_dir, "tfhub_modules.zip"), Path.Join(Path.GetTempPath(), "tfhub_modules")); + Compress.UnZip(Path.Join(data_dir, "tfhub_modules.zip"), "tfhub_modules"); // Prepare necessary directories that can be used during training Directory.CreateDirectory(summaries_dir); diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index e64e6df8..470399ff 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -17,6 +17,7 @@ + diff --git a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs index b3ed9bb2..4caf3b58 100644 --- a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs +++ b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs @@ -5,12 +5,10 @@ using System.Diagnostics; using System.IO; using System.Linq; using System.Text; +using Newtonsoft.Json; using NumSharp; using Tensorflow; -using Tensorflow.Keras.Engine; using Tensorflow.Sessions; -using TensorFlowNET.Examples.Text.cnn_models; -using TensorFlowNET.Examples.TextClassification; using TensorFlowNET.Examples.Utility; using static Tensorflow.Python; @@ -59,10 +57,10 @@ namespace TensorFlowNET.Examples //int classes = y.Data().Distinct().Count(); //int samples = len / classes; int train_size = (int)Math.Round(len * (1 - test_size)); - var train_x = x[new Slice(stop: train_size), new Slice()]; - var valid_x = x[new Slice(start: train_size), new Slice()]; - var train_y = y[new Slice(stop: train_size)]; - var valid_y = y[new Slice(start: train_size)]; + train_x = x[new Slice(stop: train_size), new Slice()]; + valid_x = x[new Slice(start: train_size), new Slice()]; + train_y = y[new Slice(stop: train_size)]; + valid_y = y[new Slice(start: train_size)]; Console.WriteLine("\tDONE"); return (train_x, valid_x, train_y, valid_y); } @@ -137,7 +135,8 @@ namespace TensorFlowNET.Examples { // delete old cached file which contains errors Console.WriteLine("Discarding cached file: " + meta_path); - File.Delete(meta_path); + if(File.Exists(meta_path)) + File.Delete(meta_path); } var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file; Web.Download(url, "graph", meta_file); @@ -197,17 +196,17 @@ namespace TensorFlowNET.Examples var h_pool = tf.concat(pooled_outputs, 3); var h_pool_flat = tf.reshape(h_pool, new TensorShape(-1, num_filters * filter_sizes.Rank)); - + Tensor h_drop = null; with(tf.name_scope("dropout"), delegate { - var h_drop = tf.nn.dropout(h_pool_flat, keep_prob); + h_drop = tf.nn.dropout(h_pool_flat, keep_prob); }); Tensor logits = null; Tensor predictions = null; with(tf.name_scope("output"), delegate { - logits = tf.layers.dense(h_pool_flat, NUM_CLASS); + logits = 
tf.layers.dense(h_drop, NUM_CLASS); predictions = tf.argmax(logits, -1, output_type: tf.int32); }); diff --git a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs index 214bb835..a267324e 100644 --- a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs +++ b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs @@ -91,6 +91,15 @@ namespace TensorFlowNET.ExamplesTests new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Run(); } + + [TestMethod] + public void CnnTextClassificationTrain() + { + tf.Graph().as_default(); + new CnnTextClassification() { Enabled = true, IsImportingGraph = false }.Run(); + } + + [Ignore] [TestMethod] public void TextClassificationWithMovieReviews() diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj index 5e619165..f76ca132 100644 --- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj @@ -16,7 +16,7 @@ - + diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs index 0e95fdc8..1b81fc21 100644 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs +++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs @@ -16,8 +16,8 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test public void testResourceReadInLoop() { - var embedding_matrix = variable_scope.get_variable( - "embedding_matrix", initializer: new double[,] { { 2.0 }, { 3.0 } }, use_resource: true); + //var embedding_matrix = variable_scope.get_variable( + //"embedding_matrix", initializer: new double[,] { { 2.0 }, { 3.0 } }, use_resource: true); Tensor cond(Tensor it, Tensor _) { diff --git a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs index 4b4623dc..744e52c3 100644 --- a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs +++ b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs @@ -28,7 +28,7 @@ namespace TensorFlowNET.UnitTest.nn_test public void testZeroFraction() { var x_shape = new Shape(5, 17); - var x_np = new NumPyRandom().randint(0, 2, x_shape); + var x_np = np.random.randint(0, 2, x_shape); x_np.astype(np.float32); var y_np = this._ZeroFraction(x_np);
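To close, here is a minimal sketch of how the sparse-gradient deduplication added in this change works end to end, mirroring `Optimizer._deduplicate_indexed_slices` above. This is hypothetical standalone usage: it assumes `tf.constant` accepts plain .NET arrays, and `array_ops`/`math_ops` are the internal helper classes shown in this diff.

```csharp
// Sparse gradient rows for indices [0, 2, 0]: the repeated index 0 is
// merged into a single summed update before the optimizer applies it.
var indices = tf.constant(new int[] { 0, 2, 0 });
var values = tf.constant(new float[,] { { 1f }, { 2f }, { 3f } });

// unique -> unique_indices = [0, 2], positions = [0, 1, 0]
var (unique_indices, positions) = array_ops.unique(indices);

// Sum rows that map to the same unique index: segment 0 -> 1 + 3 = 4, segment 1 -> 2.
// Note that array_ops.shape(...)[0] exercises the new Tensor this[int] indexer.
var summed = math_ops.unsorted_segment_sum(
    values, positions, array_ops.shape(unique_indices)[0]);
```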