@@ -170,6 +170,14 @@ namespace Tensorflow.Gradients
public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads)
=> _BaseFusedBatchNormGrad(op, 0, grads);
[RegisterGradient("FusedBatchNormV2")]
public static Tensor[] _FusedBatchNormV2Grad(Operation op, Tensor[] grads)
=> _BaseFusedBatchNormGrad(op, 1, grads);
[RegisterGradient("FusedBatchNormV3")]
public static Tensor[] _FusedBatchNormV3Grad(Operation op, Tensor[] grads)
=> _BaseFusedBatchNormGrad(op, 2, grads);
/// <summary>
/// Return the gradients for the 3 inputs of BatchNorm.
/// </summary>
@@ -190,8 +198,10 @@ namespace Tensorflow.Gradients
switch (version)
{
case 2:
throw new NotImplementedException("");
grad_fun = gen_nn_ops.fused_batch_norm_grad_v3;
break;
case 1:
// grad_fun = gen_nn_ops.fused_batch_norm_grad_v2;
throw new NotImplementedException("");
default:
grad_fun = gen_nn_ops.fused_batch_norm_grad;
@@ -225,8 +235,8 @@ namespace Tensorflow.Gradients
YBackprop = grad_y,
X = x,
Scale = scale,
ReserveSpace1 = op.outputs[3] ,
ReserveSpace2 = op.outputs[4] ,
ReserveSpace1 = pop_mean ,
ReserveSpace2 = pop_var ,
ReserveSpace3 = version == 2 ? op.outputs[5] : null,
Epsilon = epsilon,
DataFormat = data_format,