diff --git a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs index 9054e168..b082a720 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs @@ -64,6 +64,8 @@ namespace Tensorflow // Get a uid for this call to gradients that can be used to help // cluster ops for compilation. var gradient_uid = ops.get_default_graph().unique_name("uid"); + ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name: "y"); + xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name: "x", as_ref: true); grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid); /** @@ -148,7 +150,7 @@ namespace Tensorflow } else { - in_grads = _NonEagerInputs(op, xs).Select(x => new Tensor(IntPtr.Zero)).ToArray(); + in_grads = new Tensor[_NonEagerInputs(op, xs).Count()]; } var inputs = _NonEagerInputs(op, xs).ToList(); @@ -226,6 +228,7 @@ namespace Tensorflow private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor out_grads, Action func, Func grad_fn) { + scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope; return grad_fn(op, out_grads); } diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.py.cs b/src/TensorFlowNET.Core/Gradients/math_grad.py.cs index d8b0ccf9..e72c662d 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.py.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.py.cs @@ -72,9 +72,10 @@ namespace Tensorflow public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad) { - return false; - /*return string.Join(",", x.shape).Equals(string.Join(",", y.shape)) && - string.Join(",", x.shape).Equals(string.Join(",", grad.shape));*/ + if (x.NDims == 0 && y.NDims == 0 && grad.NDims == 0) return true; + + return string.Join(",", x.shape).Equals(string.Join(",", y.shape)) && + string.Join(",", x.shape).Equals(string.Join(",", grad.shape)); } public static (Tensor, Tensor) _SumGrad(Operation op, Tensor grad) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 469fb45a..30d528df 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -41,17 +41,20 @@ namespace Tensorflow _graph_key = $"grap-key-{ops.uid()}/"; } - public T as_graph_element(T obj, bool allow_tensor = true, bool allow_operation = true) + public object as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) { return _as_graph_element_locked(obj, allow_tensor, allow_operation); } - private Func _as_graph_element(object obj) + private Tensor _as_graph_element(object obj) { + if (obj is RefVariable var) + return var._as_graph_element(); + return null; } - private T _as_graph_element_locked(T obj, bool allow_tensor = true, bool allow_operation = true) + private object _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true) { string types_str = ""; @@ -69,12 +72,14 @@ namespace Tensorflow } var temp_obj = _as_graph_element(obj); + if (temp_obj != null) + obj = temp_obj; if (obj is Tensor tensor && allow_tensor) { if (tensor.Graph.Equals(this)) { - return obj; + return tensor; } else { @@ -85,7 +90,7 @@ namespace Tensorflow { if (op.Graph.Equals(this)) { - return obj; + return op; } else { @@ -93,7 +98,7 @@ namespace Tensorflow } } - throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}."); + throw new Exception($"Can not convert a {obj.GetType().Name} into a {types_str}."); } public void add_to_collection(string name, T value) diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 0b6186f2..13234366 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -35,6 +35,11 @@ namespace Tensorflow c_api.TF_DeleteSessionOptions(opts); } + public virtual NDArray run(RefVariable fetches, FeedItem[] feed_dict = null) + { + return _run(fetches, feed_dict); + } + public virtual NDArray run(Tensor fetches, FeedItem[] feed_dict = null) { return _run(fetches, feed_dict); diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 7b97cf11..52d5f056 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -109,20 +109,13 @@ namespace Tensorflow }); } - public Tensor _ref() - { - return _variable; - } + public Tensor _ref() => _variable; - public Tensor value() - { - return _snapshot; - } + public Tensor value() => _snapshot; - public Tensor _AsTensor() - { - return _snapshot; - } + public Tensor _AsTensor() => _snapshot; + + public Tensor _as_graph_element() => _variable; public Tensor _TensorConversionFunction(bool as_ref = false) { diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index 0c0ccfc4..74c8a5f7 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -329,6 +329,11 @@ namespace Tensorflow }; } + public static Tensor[] convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = "") + { + return internal_convert_n_to_tensor_or_indexed_slices(values, dtype: dtype, name: name); + } + public static Tensor convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "") { return internal_convert_to_tensor_or_indexed_slices(value: value, dtype: dtype, name: name, as_ref: false); @@ -339,6 +344,26 @@ namespace Tensorflow return value; } + public static Tensor[] internal_convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", bool as_ref = false) + { + var ret = new List(); + + foreach(var (i, value) in Python.enumerate(values)) + { + if (value == null) + { + ret.Add(value); + } + else + { + var n = string.IsNullOrEmpty(name) ? "" : $"{name}_{i}"; + ret.Add(internal_convert_to_tensor_or_indexed_slices(value, dtype: dtype, name: n, as_ref: as_ref)); + } + } + + return ret.ToArray(); + } + public static Tensor[] internal_convert_n_to_tensor(T[] values, DataType dtype = DataType.DtInvalid, string name = "", DataType preferred_dtype = DataType.DtInvalid, bool as_ref = false) diff --git a/src/TensorFlowNET.Core/runtimes/win-x64/native/tensorflow.dll b/src/TensorFlowNET.Core/runtimes/win-x64/native/tensorflow.dll index 4c62c8ce..583c710f 100644 Binary files a/src/TensorFlowNET.Core/runtimes/win-x64/native/tensorflow.dll and b/src/TensorFlowNET.Core/runtimes/win-x64/native/tensorflow.dll differ diff --git a/test/TensorFlowNET.Examples/LinearRegression.cs b/test/TensorFlowNET.Examples/LinearRegression.cs index d9fb7702..79db1ed9 100644 --- a/test/TensorFlowNET.Examples/LinearRegression.cs +++ b/test/TensorFlowNET.Examples/LinearRegression.cs @@ -80,9 +80,9 @@ namespace TensorFlowNET.Examples new FeedItem(X, train_X), new FeedItem(Y, train_Y) }); - + var rW = sess.run(W); Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + - $"W={sess.run(W)} b={sess.run(b)}"); + $"W={rW} b={sess.run(b)}"); } } diff --git a/test/TensorFlowNET.UnitTest/GradientTest.cs b/test/TensorFlowNET.UnitTest/GradientTest.cs index 6b937425..bc887764 100644 --- a/test/TensorFlowNET.UnitTest/GradientTest.cs +++ b/test/TensorFlowNET.UnitTest/GradientTest.cs @@ -23,8 +23,7 @@ namespace TensorFlowNET.UnitTest Assert.AreEqual(ys.op.inputs[0].name, "Const:0"); Assert.AreEqual(ys.op.inputs[1].name, "mul:0"); - var xs = new Tensor[] { a, b }; - var g = tf.gradients(ys, xs, stop_gradients: new Tensor[] { a, b }); + var g = tf.gradients(ys, new Tensor[] { a, b }, stop_gradients: new Tensor[] { a, b }); Assert.AreEqual(g[0].name, "gradients/Fill:0"); Assert.AreEqual(g[1].name, "gradients/Fill:0"); }