Browse Source

length issue in _ShapesFullySpecifiedAndEqual

tags/v0.12
Oceania2018 6 years ago
parent
commit
740c2db003
1 changed files with 34 additions and 1 deletions
  1. +34
    -1
      src/TensorFlowNET.Core/Gradients/math_grad.cs

+ 34
- 1
src/TensorFlowNET.Core/Gradients/math_grad.cs View File

@@ -96,6 +96,13 @@ namespace Tensorflow.Gradients
});
}

[RegisterGradient("GreaterEqual")]
public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
throw new NotImplementedException("_GreaterEqualGrad");
}

[RegisterGradient("Identity")]
public static Tensor[] _IdGrad(Operation op, Tensor[] grads)
{
@@ -124,6 +131,17 @@ namespace Tensorflow.Gradients
});
}

[RegisterGradient("Log1p")]
public static Tensor[] _Log1pGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
var x = op.inputs[0];
return tf_with(ops.control_dependencies(new Operation[] { grad }), dp => {
x = math_ops.conj(x);
return new Tensor[] { grad * math_ops.reciprocal(1 + x) };
});
}

[RegisterGradient("Mul")]
public static Tensor[] _MulGrad(Operation op, Tensor[] grads)
{
@@ -332,6 +350,21 @@ namespace Tensorflow.Gradients
return new Tensor[] { -grads[0] };
}

[RegisterGradient("Select")]
public static Tensor[] _SelectGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
var c = op.inputs[0];
var x = op.inputs[1];
var zeros = array_ops.zeros_like(x);
return new Tensor[]
{
null,
array_ops.where(c, grad, zeros),
array_ops.where(c, zeros, grad)
};
}

private static Tensor _safe_shape_div(Tensor x, Tensor y)
{
return math_ops.floordiv(x, gen_math_ops.maximum(y, 1));
@@ -363,7 +396,7 @@ namespace Tensorflow.Gradients
var grad_shape = grad._shape_tuple();
return Enumerable.SequenceEqual(x_shape, y_shape) &&
Enumerable.SequenceEqual(y_shape, grad_shape) &&
x.NDims != -1 &&
x_shape.Length > -1 &&
!x_shape.Contains(-1);
}



Loading…
Cancel
Save