From 14c26e7e070e038768410320e515c8de3f3d9d07 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 20 Dec 2020 17:41:03 -0600 Subject: [PATCH] Mul gradient is not correct in TensorFlowOpLayer #698 --- .../Eager/EagerRunner.TFE_FastPathExecute.cs | 3 ++- .../Functions/TapeGradientFunctions.cs | 24 +++++++---------- .../Gradients/array_grad.cs | 7 +++-- src/TensorFlowNET.Core/Gradients/math_grad.cs | 12 ++++----- .../Operations/array_ops.cs | 9 +++++++ .../Operations/gen_array_ops.cs | 2 +- .../Operations/gen_math_ops.cs | 26 +++++++++---------- .../Tensors/TensorShape.Equals.cs | 4 +++ src/TensorFlowNET.Keras/BackendImpl.cs | 2 +- .../Layers/Reshaping/Reshape.cs | 3 ++- 10 files changed, 53 insertions(+), 39 deletions(-) diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs index c703aaaa..5fe058e2 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs @@ -380,7 +380,8 @@ namespace Tensorflow.Eager c_api.TFE_OpSetAttrBool(op, key, Convert.ToBoolean(value)); break; case TF_AttrType.TF_ATTR_INT: - c_api.TFE_OpSetAttrInt(op, key, Convert.ToInt64(value)); + attr_list_sizes[key] = Convert.ToInt64(value); + c_api.TFE_OpSetAttrInt(op, key, attr_list_sizes[key]); break; case TF_AttrType.TF_ATTR_FLOAT: c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value)); diff --git a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs index 4cd59c92..19ae5fbc 100644 --- a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs +++ b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs @@ -44,7 +44,7 @@ namespace Tensorflow.Functions public void Record(Tensors flat_outputs, Tensors inference_args) { var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs); - tf.Runner.RecordGradient(_forward.Name, flat_outputs, new object[0], inference_args, + tf.Runner.RecordGradient(_forward.Name, inference_args, new object[0], to_record, getBackwardFunction: () => backward_function); } @@ -52,20 +52,16 @@ namespace Tensorflow.Functions { BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => { - return new Tensor[0]; - - /*var gradients = ops.gradientFunctions[op_name](new EagerOperation + var processed_args = new List(); + var input_index = 0; + foreach (var (output_index, arg) in enumerate(output_grads)) { - Name = op_name, - NumInputs = op_inputs.Length, - Inputs = op_inputs, - NumOutputs = op_outputs.Length, - Outputs = op_outputs, - SkipInputIndices = unneeded_gradients, - Attrs = attrs - }, output_grads); - - return gradients;*/ + if (arg is null) + throw new NotImplementedException(""); + processed_args.add(arg); + input_index += 1; + } + return output_grads;// backward.Invoke(processed_args.ToArray()); }; return (_backward_function_wrapper, flat_outputs); diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs index 8df7b970..0fe61bd2 100644 --- a/src/TensorFlowNET.Core/Gradients/array_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs @@ -85,10 +85,13 @@ namespace Tensorflow.Gradients var out_grads = new List(); if(concat_dim is EagerTensor) { - var non_neg_concat_dim = (int)concat_dim % input_values[0].rank; + var dim_int = (int)concat_dim; + var non_neg_concat_dim = dim_int < 0 + ? input_values[0].rank + dim_int + : dim_int % input_values[0].rank; var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray(); var sizes_tensor = constant_op.constant(sizes); - out_grads = gen_array_ops.split_v(grad, sizes_tensor, sizes[0], non_neg_concat_dim).ToList(); + out_grads = array_ops.split(grad, sizes_tensor, non_neg_concat_dim).ToList(); } else if (constant_op.is_constant(concat_dim)) { diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index 424493c6..9ea40816 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -212,7 +212,7 @@ namespace Tensorflow.Gradients }; } - var broads = SmartBroadcastGradientArgs(x, y); + var broads = SmartBroadcastGradientArgs(x, y, grad); var (sx, rx, must_reduce_x) = broads[0]; var (sy, ry, must_reduce_y) = broads[1]; @@ -468,7 +468,7 @@ namespace Tensorflow.Gradients _ShapesFullySpecifiedAndEqual(x, y, grad)) return new Tensor[] { grad, -grad }; - var broads = SmartBroadcastGradientArgs(x, y); + var broads = SmartBroadcastGradientArgs(x, y, grad); var (sx, rx, must_reduce_x) = broads[0]; var (sy, ry, must_reduce_y) = broads[1]; @@ -718,7 +718,7 @@ namespace Tensorflow.Gradients var z = op.outputs[0]; - var broads = SmartBroadcastGradientArgs(x, y); + var broads = SmartBroadcastGradientArgs(x, y, grad); var (sx, rx, must_reduce_x) = broads[0]; var (sy, ry, must_reduce_y) = broads[1]; @@ -753,7 +753,7 @@ namespace Tensorflow.Gradients /// /// /// - private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y) + private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y, Tensor grad) { Tensor sx, sy; if (x.TensorShape.is_fully_defined() && @@ -771,8 +771,8 @@ namespace Tensorflow.Gradients var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); return new[] { - (sx, rx, true), - (sy, ry, true) + (sx, rx, !x.TensorShape.Equals(grad.TensorShape)), + (sy, ry, !y.TensorShape.Equals(grad.TensorShape)) }; } } diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index cd862df3..1de84664 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -885,6 +885,15 @@ namespace Tensorflow }); } + public static Tensor[] split(Tensor value, Tensor size_splits, int axis, int num = -1, + string name = "split") + { + if (num == -1) + num = size_splits.shape[0]; + + return gen_array_ops.split_v(value, size_splits, axis, num, name: name); + } + public static Tensor[] split(Tensor value, int num_split, T axis, string name = "split") { diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 27a4e5da..d56813f4 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -527,7 +527,7 @@ namespace Tensorflow var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "SplitV", name, null, - value, size_splits, axis, + value, size_splits, axis, "num_split", num_split); return results; diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index ca74ea5f..b40dc2ae 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -346,21 +346,21 @@ namespace Tensorflow /// dy is the corresponding input gradient. /// public static Tensor sigmoid_grad(Tensor y, Tensor dy, string name = "SigmoidGrad") - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + => tf.Context.RunInAutoMode2( + () => tf.OpDefLib._apply_op_helper("SigmoidGrad", name, new { y, dy }).output, + () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "SigmoidGrad", name, null, - y, dy); - - return results[0]; - } - - var op = tf.OpDefLib._apply_op_helper("SigmoidGrad", name: name, args: new { y, dy }); - - return op.output; - } + y, dy).FirstOrDefault(), + (op) => + { + var attrs = new object[] + { + "T", op.get_attr("T") + }; + tf.Runner.RecordGradient("SigmoidGrad", op.inputs, attrs, op.outputs); + }, + new Tensors(y, dy)); public static Tensor sign(T x, string name = "Sign") { diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.Equals.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.Equals.cs index 7f23cc58..9078dbed 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.Equals.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.Equals.cs @@ -10,6 +10,10 @@ namespace Tensorflow switch (obj) { case TensorShape shape1: + if (rank == -1 && shape1.rank == -1) + return false; + else if (rank != shape1.rank) + return false; return Enumerable.SequenceEqual(shape1.dims, dims); default: return false; diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index a55791d2..d11de338 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -239,7 +239,7 @@ namespace Tensorflow.Keras { var rank = tensors[0].NDims; if (rank > -1) - axis %= rank; + axis += rank; else axis = 0; } diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs index 28c7be3e..dcf27c24 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs @@ -21,7 +21,8 @@ namespace Tensorflow.Keras.Layers protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) { - var shape = new List { inputs.shape[0] }; + var shape_tensor = array_ops.shape(inputs); + var shape = new List { shape_tensor.shape[0] }; shape.AddRange(args.TargetShape.dims); var result = array_ops.reshape(inputs, shape.ToArray());