|
|
|
@@ -96,6 +96,13 @@ namespace Tensorflow.Gradients |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
[RegisterGradient("GreaterEqual")] |
|
|
|
public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads) |
|
|
|
{ |
|
|
|
var grad = grads[0]; |
|
|
|
throw new NotImplementedException("_GreaterEqualGrad"); |
|
|
|
} |
|
|
|
|
|
|
|
[RegisterGradient("Identity")] |
|
|
|
public static Tensor[] _IdGrad(Operation op, Tensor[] grads) |
|
|
|
{ |
|
|
|
@@ -124,6 +131,17 @@ namespace Tensorflow.Gradients |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
[RegisterGradient("Log1p")] |
|
|
|
public static Tensor[] _Log1pGrad(Operation op, Tensor[] grads) |
|
|
|
{ |
|
|
|
var grad = grads[0]; |
|
|
|
var x = op.inputs[0]; |
|
|
|
return tf_with(ops.control_dependencies(new Operation[] { grad }), dp => { |
|
|
|
x = math_ops.conj(x); |
|
|
|
return new Tensor[] { grad * math_ops.reciprocal(1 + x) }; |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
[RegisterGradient("Mul")] |
|
|
|
public static Tensor[] _MulGrad(Operation op, Tensor[] grads) |
|
|
|
{ |
|
|
|
@@ -332,6 +350,21 @@ namespace Tensorflow.Gradients |
|
|
|
return new Tensor[] { -grads[0] }; |
|
|
|
} |
|
|
|
|
|
|
|
[RegisterGradient("Select")] |
|
|
|
public static Tensor[] _SelectGrad(Operation op, Tensor[] grads) |
|
|
|
{ |
|
|
|
var grad = grads[0]; |
|
|
|
var c = op.inputs[0]; |
|
|
|
var x = op.inputs[1]; |
|
|
|
var zeros = array_ops.zeros_like(x); |
|
|
|
return new Tensor[] |
|
|
|
{ |
|
|
|
null, |
|
|
|
array_ops.where(c, grad, zeros), |
|
|
|
array_ops.where(c, zeros, grad) |
|
|
|
}; |
|
|
|
} |
|
|
|
|
|
|
|
private static Tensor _safe_shape_div(Tensor x, Tensor y) |
|
|
|
{ |
|
|
|
return math_ops.floordiv(x, gen_math_ops.maximum(y, 1)); |
|
|
|
@@ -363,7 +396,7 @@ namespace Tensorflow.Gradients |
|
|
|
var grad_shape = grad._shape_tuple(); |
|
|
|
return Enumerable.SequenceEqual(x_shape, y_shape) && |
|
|
|
Enumerable.SequenceEqual(y_shape, grad_shape) && |
|
|
|
x.NDims != -1 && |
|
|
|
x_shape.Length > -1 && |
|
|
|
!x_shape.Contains(-1); |
|
|
|
} |
|
|
|
|
|
|
|
|