@@ -64,6 +64,8 @@ namespace Tensorflow
             // Get a uid for this call to gradients that can be used to help
             // cluster ops for compilation.
             var gradient_uid = ops.get_default_graph().unique_name("uid");
+            ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name: "y");
+            xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name: "x", as_ref: true);
             grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid);

             /**
@@ -148,7 +150,7 @@ namespace Tensorflow
                     }
                     else
                     {
-                        in_grads = _NonEagerInputs(op, xs).Select(x => new Tensor(IntPtr.Zero)).ToArray();
+                        in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
                     }

                     var inputs = _NonEagerInputs(op, xs).ToList();
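In the no-gradient-function branch, the old code wrapped a null handle in placeholder Tensor objects; allocating a plain Tensor[] instead leaves every slot null, presumably so downstream aggregation can skip them. A minimal sketch of the C# behaviour this relies on (inputCount is a hypothetical stand-in for _NonEagerInputs(op, xs).Count()):

// Illustration only: a newly allocated array of a reference type such as
// Tensor contains null entries; no wrapper objects are created.
int inputCount = 3;
Tensor[] in_grads = new Tensor[inputCount];
Console.WriteLine(in_grads[0] == null);   // prints True: every slot starts out as null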
@@ -226,6 +228,7 @@ namespace Tensorflow
         private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor out_grads, Action func, Func<Operation, Tensor, Tensor[]> grad_fn)
         {
+            scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope;
             return grad_fn(op, out_grads);
         }
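The added guard strips a single trailing "/" from the scope name before it is reused; a quick sketch of the string handling involved (the sample value is made up for illustration):

// Trim one trailing slash, leave names without one untouched.
string scope = "gradients/";
scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope;
Console.WriteLine(scope);   // gradients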
@@ -72,9 +72,10 @@ namespace Tensorflow
         public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad)
         {
-            return false;
-            /*return string.Join(",", x.shape).Equals(string.Join(",", y.shape)) &&
-                string.Join(",", x.shape).Equals(string.Join(",", grad.shape));*/
+            if (x.NDims == 0 && y.NDims == 0 && grad.NDims == 0) return true;
+            return string.Join(",", x.shape).Equals(string.Join(",", y.shape)) &&
+                string.Join(",", x.shape).Equals(string.Join(",", grad.shape));
         }

         public static (Tensor, Tensor) _SumGrad(Operation op, Tensor grad)
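The reinstated check treats shapes as equal when their comma-joined dimension lists match, with an early-out when all three tensors are scalars. A rough sketch of the comparison strategy, assuming shape is exposed as an int[] (the variable names here are illustrative, not the library's):

// Join the dimensions into "2,3"-style strings and compare those, so the
// int[] contents are compared by value rather than by array reference.
int[] xShape = { 2, 3 };
int[] yShape = { 2, 3 };
int[] gradShape = { 2, 3 };
bool equal = string.Join(",", xShape).Equals(string.Join(",", yShape)) &&
             string.Join(",", xShape).Equals(string.Join(",", gradShape));
Console.WriteLine(equal);   // True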
@@ -41,17 +41,20 @@ namespace Tensorflow
             _graph_key = $"grap-key-{ops.uid()}/";
         }

-        public T as_graph_element<T>(T obj, bool allow_tensor = true, bool allow_operation = true)
+        public object as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true)
         {
             return _as_graph_element_locked(obj, allow_tensor, allow_operation);
         }

-        private Func<object> _as_graph_element(object obj)
+        private Tensor _as_graph_element(object obj)
         {
+            if (obj is RefVariable var)
+                return var._as_graph_element();
             return null;
         }

-        private T _as_graph_element_locked<T>(T obj, bool allow_tensor = true, bool allow_operation = true)
+        private object _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true)
         {
             string types_str = "";
@@ -69,12 +72,14 @@ namespace Tensorflow
             }

             var temp_obj = _as_graph_element(obj);
+            if (temp_obj != null)
+                obj = temp_obj;

             if (obj is Tensor tensor && allow_tensor)
             {
                 if (tensor.Graph.Equals(this))
                 {
-                    return obj;
+                    return tensor;
                 }
                 else
                 {
@@ -85,7 +90,7 @@ namespace Tensorflow
             {
                 if (op.Graph.Equals(this))
                 {
-                    return obj;
+                    return op;
                 }
                 else
                 {
@@ -93,7 +98,7 @@ namespace Tensorflow
                 }
             }

-            throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}.");
+            throw new Exception($"Can not convert a {obj.GetType().Name} into a {types_str}.");
         }

         public void add_to_collection<T>(string name, T value)
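Dropping the generic parameter lets as_graph_element accept any fetch-like object and resolve it through _as_graph_element first, so a RefVariable is redirected to its underlying graph tensor before the Tensor/Operation checks run. A hedged usage sketch (W stands for some RefVariable created elsewhere; the cast reflects the new object return type):

// Sketch, assuming W is a RefVariable that lives in the default graph:
// as_graph_element(W) calls _as_graph_element, which maps the RefVariable to
// its _variable tensor, and that tensor is what comes back.
var graph = ops.get_default_graph();
var element = graph.as_graph_element(W, allow_tensor: true, allow_operation: false);
var tensor = (Tensor)element;   // explicit cast needed now that the return type is object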
@@ -35,6 +35,11 @@ namespace Tensorflow
             c_api.TF_DeleteSessionOptions(opts);
         }

+        public virtual NDArray run(RefVariable fetches, FeedItem[] feed_dict = null)
+        {
+            return _run(fetches, feed_dict);
+        }
+
         public virtual NDArray run(Tensor fetches, FeedItem[] feed_dict = null)
         {
             return _run(fetches, feed_dict);
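The new overload lets a RefVariable be fetched directly, mirroring the existing Tensor overload. A minimal usage sketch; sess, W, cost, X, Y, train_X and train_Y are placeholders assumed to be defined elsewhere:

// Sketch only: fetch a RefVariable the same way a Tensor is fetched.
NDArray result = sess.run(W);                    // new RefVariable overload, no manual conversion
NDArray c = sess.run(cost, new FeedItem[] {      // existing Tensor overload, unchanged
    new FeedItem(X, train_X),
    new FeedItem(Y, train_Y)
});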
@@ -109,20 +109,13 @@ namespace Tensorflow
             });
         }

-        public Tensor _ref()
-        {
-            return _variable;
-        }
+        public Tensor _ref() => _variable;

-        public Tensor value()
-        {
-            return _snapshot;
-        }
+        public Tensor value() => _snapshot;

-        public Tensor _AsTensor()
-        {
-            return _snapshot;
-        }
+        public Tensor _AsTensor() => _snapshot;
+
+        public Tensor _as_graph_element() => _variable;

         public Tensor _TensorConversionFunction(bool as_ref = false)
         {
@@ -329,6 +329,11 @@ namespace Tensorflow
             };
         }

+        public static Tensor[] convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = "")
+        {
+            return internal_convert_n_to_tensor_or_indexed_slices(values, dtype: dtype, name: name);
+        }
+
         public static Tensor convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "")
         {
             return internal_convert_to_tensor_or_indexed_slices(value: value, dtype: dtype, name: name, as_ref: false);
@@ -339,6 +344,26 @@ namespace Tensorflow
             return value;
         }

+        public static Tensor[] internal_convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", bool as_ref = false)
+        {
+            var ret = new List<Tensor>();
+
+            foreach (var (i, value) in Python.enumerate(values))
+            {
+                if (value == null)
+                {
+                    ret.Add(value);
+                }
+                else
+                {
+                    var n = string.IsNullOrEmpty(name) ? "" : $"{name}_{i}";
+                    ret.Add(internal_convert_to_tensor_or_indexed_slices(value, dtype: dtype, name: n, as_ref: as_ref));
+                }
+            }
+
+            return ret.ToArray();
+        }
+
         public static Tensor[] internal_convert_n_to_tensor<T>(T[] values, DataType dtype = DataType.DtInvalid,
             string name = "", DataType preferred_dtype = DataType.DtInvalid,
             bool as_ref = false)
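The public wrapper plus the internal method normalize an array of fetches in one pass: null entries pass through untouched and each non-null entry is converted under a per-element name derived from the base name. A usage sketch, assuming t0 and t1 are pre-existing tensors:

// Sketch: nulls are preserved; non-null entries get names like "y_0", "y_2",
// following the $"{name}_{i}" pattern in the method above (index-based suffix).
Tensor[] values = { t0, null, t1 };
Tensor[] converted = ops.convert_n_to_tensor_or_indexed_slices(values, name: "y");
// converted[1] stays null; converted[0] and converted[2] come back converted
// under the names "y_0" and "y_2" respectively.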
@@ -80,9 +80,9 @@ namespace TensorFlowNET.Examples
                         new FeedItem(X, train_X),
                         new FeedItem(Y, train_Y)
                     });
+                    var rW = sess.run(W);
                     Console.WriteLine($"Epoch: {epoch + 1} cost={c} " +
-                        $"W={sess.run(W)} b={sess.run(b)}");
+                        $"W={rW} b={sess.run(b)}");
                 }
             }
@@ -23,8 +23,7 @@ namespace TensorFlowNET.UnitTest
             Assert.AreEqual(ys.op.inputs[0].name, "Const:0");
             Assert.AreEqual(ys.op.inputs[1].name, "mul:0");

-            var xs = new Tensor[] { a, b };
-            var g = tf.gradients(ys, xs, stop_gradients: new Tensor[] { a, b });
+            var g = tf.gradients(ys, new Tensor[] { a, b }, stop_gradients: new Tensor[] { a, b });

             Assert.AreEqual(g[0].name, "gradients/Fill:0");
             Assert.AreEqual(g[1].name, "gradients/Fill:0");
         }