diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 1be2970a..f18ece8c 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -203,6 +203,9 @@ namespace Tensorflow public static Tensor sqrt(Tensor a, string name = null) => gen_math_ops.sqrt(a, name); + public static Tensor sign(Tensor a, string name = null) + => gen_math_ops.sign(a, name); + public static Tensor subtract(Tensor x, T[] y, string name = null) where T : struct => gen_math_ops.sub(x, ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"), name); diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index e6fd68e7..a84185f3 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -16,6 +16,7 @@ using System; using System.Linq; +using Tensorflow.Operations; using static Tensorflow.Python; namespace Tensorflow.Gradients @@ -26,6 +27,15 @@ namespace Tensorflow.Gradients [RegisterGradient("math_grad")] public class math_grad { + [RegisterGradient("Abs")] + public static Tensor[] _AbsGrad(Operation op, Tensor[] grads) + { + var x = op.inputs[0]; + var grad = grads[0]; + + return new Tensor[] { gen_ops.mul(grad, gen_math_ops.sign(x)) }; + } + [RegisterGradient("Add")] public static Tensor[] _AddGrad(Operation op, Tensor[] grads) { @@ -428,6 +438,15 @@ namespace Tensorflow.Gradients }); } + [RegisterGradient("Sign")] + public static Tensor[] _SignGrad(Operation op, Tensor[] grads) + { + var x = op.inputs[0]; + var zero = constant_op.constant(0.0f, x.dtype, x.shape); + + return new Tensor[] {zero}; + } + [RegisterGradient("Square")] public static Tensor[] _SquareGrad(Operation op, Tensor[] grads) { diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index dca38a9a..0999ad59 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -210,6 +210,13 @@ namespace Tensorflow return op.outputs[0]; } + public static Tensor sign(Tensor x, string name = "Sign") + { + var op = _op_def_lib._apply_op_helper("Sign", name: name, args: new {x}); + + return op.outputs[0]; + } + public static Tensor sinh(Tensor x, string name = null) { var _op = _op_def_lib._apply_op_helper("Sinh", name, args: new { x }); diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index 9410b1b0..486fea8d 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -216,6 +216,15 @@ namespace Tensorflow return gen_math_ops.sigmoid(x_tensor, name: name); } + public static Tensor sign(Tensor x, string name = null) + { + return with(ops.name_scope(name, "Sign", new {x}), scope => + { + x = ops.convert_to_tensor(x, name: "x"); + return gen_math_ops.sign(x); + }); + } + /// /// Returns (x - y)(x - y) element-wise. /// diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index 87cedc59..027465bf 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -72,5 +72,10 @@ namespace Tensorflow { return new Session(graph); } + + public static Session Session(SessionOptions opts) + { + return new Session(null, opts); + } } }