|
|
@@ -212,8 +212,9 @@ namespace Tensorflow.Gradients |
|
|
}; |
|
|
}; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
var (sx, sy) = SmartBroadcastGradientArgs(x, y); |
|
|
|
|
|
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); |
|
|
|
|
|
|
|
|
var broads = SmartBroadcastGradientArgs(x, y); |
|
|
|
|
|
var (sx, rx, must_reduce_x) = broads[0]; |
|
|
|
|
|
var (sy, ry, must_reduce_y) = broads[1]; |
|
|
|
|
|
|
|
|
x = math_ops.conj(x); |
|
|
x = math_ops.conj(x); |
|
|
y = math_ops.conj(y); |
|
|
y = math_ops.conj(y); |
|
|
@@ -222,33 +223,21 @@ namespace Tensorflow.Gradients |
|
|
|
|
|
|
|
|
if (op is EagerOperation op_eager1 && |
|
|
if (op is EagerOperation op_eager1 && |
|
|
op_eager1.SkipInputIndices.Contains(0)) |
|
|
op_eager1.SkipInputIndices.Contains(0)) |
|
|
{ |
|
|
|
|
|
return new Tensor[] |
|
|
|
|
|
{ |
|
|
|
|
|
gen_math_ops.mul(grad, math_ops.conj(y)), |
|
|
|
|
|
null |
|
|
|
|
|
}; |
|
|
|
|
|
} |
|
|
|
|
|
// else if not must_reduce_x: |
|
|
|
|
|
// gx = gen_math_ops.mul(grad, y) |
|
|
|
|
|
|
|
|
gy = null; |
|
|
|
|
|
else if (!must_reduce_x) |
|
|
|
|
|
gx = gen_math_ops.mul(grad, y); |
|
|
else |
|
|
else |
|
|
{ |
|
|
|
|
|
gx = array_ops.reshape( |
|
|
gx = array_ops.reshape( |
|
|
math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx); |
|
|
math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (op is EagerOperation op_eager2 && |
|
|
if (op is EagerOperation op_eager2 && |
|
|
op_eager2.SkipInputIndices.Contains(1)) |
|
|
op_eager2.SkipInputIndices.Contains(1)) |
|
|
{ |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
// else if not must_reduce_y: |
|
|
|
|
|
// gy = gen_math_ops.mul(x, grad) |
|
|
|
|
|
|
|
|
gy = null; |
|
|
|
|
|
else if (!must_reduce_y) |
|
|
|
|
|
gy = gen_math_ops.mul(x, grad); |
|
|
else |
|
|
else |
|
|
{ |
|
|
|
|
|
gy = array_ops.reshape( |
|
|
gy = array_ops.reshape( |
|
|
math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy); |
|
|
math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
return new Tensor[] { gx, gy }; |
|
|
return new Tensor[] { gx, gy }; |
|
|
} |
|
|
} |
|
|
@@ -479,8 +468,9 @@ namespace Tensorflow.Gradients |
|
|
_ShapesFullySpecifiedAndEqual(x, y, grad)) |
|
|
_ShapesFullySpecifiedAndEqual(x, y, grad)) |
|
|
return new Tensor[] { grad, -grad }; |
|
|
return new Tensor[] { grad, -grad }; |
|
|
|
|
|
|
|
|
var (sx, sy) = SmartBroadcastGradientArgs(x, y); |
|
|
|
|
|
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); |
|
|
|
|
|
|
|
|
var broads = SmartBroadcastGradientArgs(x, y); |
|
|
|
|
|
var (sx, rx, must_reduce_x) = broads[0]; |
|
|
|
|
|
var (sy, ry, must_reduce_y) = broads[1]; |
|
|
|
|
|
|
|
|
var gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx); |
|
|
var gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx); |
|
|
var gy = array_ops.reshape(math_ops.reduce_sum(-grad, ry), sy); |
|
|
var gy = array_ops.reshape(math_ops.reduce_sum(-grad, ry), sy); |
|
|
@@ -728,8 +718,10 @@ namespace Tensorflow.Gradients |
|
|
|
|
|
|
|
|
var z = op.outputs[0]; |
|
|
var z = op.outputs[0]; |
|
|
|
|
|
|
|
|
var (sx, sy) = SmartBroadcastGradientArgs(x, y); |
|
|
|
|
|
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); |
|
|
|
|
|
|
|
|
var broads = SmartBroadcastGradientArgs(x, y); |
|
|
|
|
|
var (sx, rx, must_reduce_x) = broads[0]; |
|
|
|
|
|
var (sy, ry, must_reduce_y) = broads[1]; |
|
|
|
|
|
|
|
|
x = math_ops.conj(x); |
|
|
x = math_ops.conj(x); |
|
|
y = math_ops.conj(y); |
|
|
y = math_ops.conj(y); |
|
|
z = math_ops.conj(z); |
|
|
z = math_ops.conj(z); |
|
|
@@ -761,7 +753,7 @@ namespace Tensorflow.Gradients |
|
|
/// <param name="x"></param> |
|
|
/// <param name="x"></param> |
|
|
/// <param name="y"></param> |
|
|
/// <param name="y"></param> |
|
|
/// <returns></returns> |
|
|
/// <returns></returns> |
|
|
private static (Tensor, Tensor) SmartBroadcastGradientArgs(Tensor x, Tensor y) |
|
|
|
|
|
|
|
|
private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y) |
|
|
{ |
|
|
{ |
|
|
Tensor sx, sy; |
|
|
Tensor sx, sy; |
|
|
if (x.TensorShape.is_fully_defined() && |
|
|
if (x.TensorShape.is_fully_defined() && |
|
|
@@ -769,6 +761,13 @@ namespace Tensorflow.Gradients |
|
|
{ |
|
|
{ |
|
|
sx = array_ops.shape(x); |
|
|
sx = array_ops.shape(x); |
|
|
sy = array_ops.shape(y); |
|
|
sy = array_ops.shape(y); |
|
|
|
|
|
|
|
|
|
|
|
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); |
|
|
|
|
|
return new[] |
|
|
|
|
|
{ |
|
|
|
|
|
(sx, rx, true), |
|
|
|
|
|
(sy, ry, true) |
|
|
|
|
|
}; |
|
|
} |
|
|
} |
|
|
else |
|
|
else |
|
|
{ |
|
|
{ |
|
|
@@ -776,7 +775,7 @@ namespace Tensorflow.Gradients |
|
|
sy = array_ops.shape_internal(y, optimize: false); |
|
|
sy = array_ops.shape_internal(y, optimize: false); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
return (sx, sy); |
|
|
|
|
|
|
|
|
throw new NotImplementedException(""); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |