diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index 672270b9..85f849cc 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -96,6 +96,13 @@ namespace Tensorflow.Gradients }); } + [RegisterGradient("GreaterEqual")] + public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + throw new NotImplementedException("_GreaterEqualGrad"); + } + [RegisterGradient("Identity")] public static Tensor[] _IdGrad(Operation op, Tensor[] grads) { @@ -124,6 +131,17 @@ namespace Tensorflow.Gradients }); } + [RegisterGradient("Log1p")] + public static Tensor[] _Log1pGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + return tf_with(ops.control_dependencies(new Operation[] { grad }), dp => { + x = math_ops.conj(x); + return new Tensor[] { grad * math_ops.reciprocal(1 + x) }; + }); + } + [RegisterGradient("Mul")] public static Tensor[] _MulGrad(Operation op, Tensor[] grads) { @@ -332,6 +350,21 @@ namespace Tensorflow.Gradients return new Tensor[] { -grads[0] }; } + [RegisterGradient("Select")] + public static Tensor[] _SelectGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var c = op.inputs[0]; + var x = op.inputs[1]; + var zeros = array_ops.zeros_like(x); + return new Tensor[] + { + null, + array_ops.where(c, grad, zeros), + array_ops.where(c, zeros, grad) + }; + } + private static Tensor _safe_shape_div(Tensor x, Tensor y) { return math_ops.floordiv(x, gen_math_ops.maximum(y, 1)); @@ -363,7 +396,7 @@ namespace Tensorflow.Gradients var grad_shape = grad._shape_tuple(); return Enumerable.SequenceEqual(x_shape, y_shape) && Enumerable.SequenceEqual(y_shape, grad_shape) && - x.NDims != -1 && + x_shape.Length > -1 && !x_shape.Contains(-1); }