diff --git a/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs b/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs
index 670731e0..d8447163 100644
--- a/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs
@@ -36,7 +36,7 @@ namespace Tensorflow.Gradients
///
///
[RegisterGradient("Switch")]
- public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads)
+ public Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
{
throw new NotImplementedException("_SwitchGrad");
//graph = ops.get_default_graph()
diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs
index 7252301a..a4508d3c 100644
--- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs
+++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs
@@ -108,7 +108,7 @@ namespace Tensorflow
{
// generate gradient subgraph for op.
var op = queue.Dequeue();
- if(tf.get_default_graph()._nodes_by_name.Count > 18505)
+ if(tf.get_default_graph()._nodes_by_name.Count > 18577)
{
}
diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
index 7b5d2ea7..967b3c21 100644
--- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
@@ -166,6 +166,94 @@ namespace Tensorflow.Gradients
};
}
+        [RegisterGradient("FusedBatchNorm")]
+        public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads)
+        { return _BaseFusedBatchNormGrad(op, version: 0, grads: grads); } // version 0 falls through to the default case of the version switch
+
+        /// <summary>
+        /// Return the gradients for the 3 inputs of BatchNorm.
+        /// </summary>
+        /// <param name="op">The forward batch-norm op; its inputs, outputs and the epsilon, data_format and is_training attributes are read.</param>
+        /// <param name="version">Kernel selector: 1 and 2 throw NotImplementedException; any other value uses gen_nn_ops.fused_batch_norm_grad.</param>
+        /// <param name="grads">Incoming gradients; only grads[0] (the gradient with respect to y) is used.</param>
+        /// <returns>Training mode: the outputs of the gradient op. Otherwise: { dx, dscale, doffset, null, null }.</returns>
+        public static Tensor[] _BaseFusedBatchNormGrad(Operation op, int version, Tensor[] grads)
+        {
+            var x = op.inputs[0];
+            var grad_y = grads[0];
+            var scale = op.inputs[1];
+            var epsilon = op.get_attr<float>("epsilon");
+            var data_format = op.get_attr<string>("data_format");
+            var is_training = op.get_attr<bool>("is_training");
+            Func<FusedBatchNormParams, Tensor[]> grad_fun = null;
+
+            switch (version)
+            {
+                case 2:
+                    throw new NotImplementedException("_BaseFusedBatchNormGrad version 2");
+                case 1:
+                    throw new NotImplementedException("_BaseFusedBatchNormGrad version 1");
+                default:
+                    grad_fun = gen_nn_ops.fused_batch_norm_grad;
+                    break;
+            }
+
+            if (is_training)
+            {
+                return grad_fun(new FusedBatchNormParams
+                {
+                    YBackprop = grad_y,
+                    X = x,
+                    Scale = scale,
+                    ReserveSpace1 = op.outputs[3], // training: reuse the values the forward op emitted
+                    ReserveSpace2 = op.outputs[4],
+                    ReserveSpace3 = version == 2 ? op.outputs[5] : null, // always null today: version 2 throws above
+                    Epsilon = epsilon,
+                    DataFormat = data_format,
+                    IsTraining = is_training
+                });
+            }
+            else
+            {
+                var pop_mean = op.inputs[3];
+                var pop_var = op.inputs[4];
+                if (data_format == "NCHW")
+                    throw new NotImplementedException("_BaseFusedBatchNormGrad NCHW with is_training == false");
+
+                // Inference: hand the population statistics (op.inputs[3], op.inputs[4]) to the gradient
+                // kernel as reserve_space_1/2, as TensorFlow's Python _BaseFusedBatchNormGrad does.
+                var results = grad_fun(new FusedBatchNormParams
+                {
+                    YBackprop = grad_y,
+                    X = x,
+                    Scale = scale,
+                    ReserveSpace1 = pop_mean,
+                    ReserveSpace2 = pop_var,
+                    ReserveSpace3 = version == 2 ? op.outputs[5] : null, // always null today: version 2 throws above
+                    Epsilon = epsilon,
+                    DataFormat = data_format,
+                    IsTraining = is_training
+                });
+
+                var (dx, dscale, doffset) = (results[0], results[1], results[2]);
+
+                return new Tensor[]
+                {
+                    dx,
+                    dscale,
+                    doffset,
+                    null, // no gradient is returned for op.inputs[3] (pop_mean)
+                    null  // no gradient is returned for op.inputs[4] (pop_var)
+                };
+            }
+        }
+
+        /// <summary>Gradient entry for "BatchNormWithGlobalNormalization"; not implemented, every call throws NotImplementedException.</summary>
+        /// <remarks>NOTE(review): declared to return Tensor while the other gradient methods in this change return Tensor[] - confirm this is intended.</remarks>
+        [RegisterGradient("BatchNormWithGlobalNormalization")]
+        public static Tensor _BatchNormWithGlobalNormalizationGrad(Operation op, Tensor[] grads)
+            => throw new NotImplementedException("BatchNormWithGlobalNormalization");
+
private static bool IsZero(Tensor g)
{
if (new string[] { "ZerosLike", "Zeros" }.Contains(g.op.type))
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
index aa314efb..ce2295c8 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
@@ -27,20 +27,6 @@ namespace Tensorflow.Operations
///
public class CondContext : ControlFlowContext, IProtoBuf
{
-
-
- ///
- /// The boolean tensor for the cond predicate
- ///
- private Tensor _pred;
-
- public Tensor pred => _pred;
-
- ///
- /// 0 or 1 representing this branch
- ///
- private int _branch;
-
private Dictionary _external_values = new Dictionary();
///
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
index 2a76c52c..c076cbc7 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
@@ -45,10 +45,19 @@ namespace Tensorflow.Operations
/// The predicate tensor in this branch
///
protected Tensor _pivot;
- public Tensor pivot
- {
- get => _pivot;
- }
+ public Tensor pivot => _pivot;
+
+ /// <summary>
+ /// The boolean tensor for the cond predicate
+ /// </summary>
+ protected Tensor _pred;
+ public Tensor pred => _pred;
+
+ /// <summary>
+ /// 0 or 1 representing this branch
+ /// </summary>
+ protected int _branch;
+ public int branch => _branch;
protected Stack _context_stack;
protected ControlFlowContext _outer_context;
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/FusedBatchNormParams.cs b/src/TensorFlowNET.Core/Operations/NnOps/FusedBatchNormParams.cs
new file mode 100644
index 00000000..689fa5fe
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/NnOps/FusedBatchNormParams.cs
@@ -0,0 +1,27 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations
+{
+    /// <summary>
+    /// Arguments consumed by gen_nn_ops.fused_batch_norm_grad when it builds the "FusedBatchNormGrad" op.
+    /// </summary>
+    public class FusedBatchNormParams
+    {
+        /// <summary>Passed as the name of the created op; stays null unless set.</summary>
+        public string Name { get; set; }
+        // Tensor arguments; fused_batch_norm_grad forwards all of these except ReserveSpace3.
+        public Tensor YBackprop { get; set; }
+        public Tensor X { get; set; }
+        public Tensor Scale { get; set; }
+        public Tensor ReserveSpace1 { get; set; }
+        public Tensor ReserveSpace2 { get; set; }
+        public Tensor ReserveSpace3 { get; set; }
+
+        // The defaults below apply whenever an object initializer leaves the property unset.
+        public float Epsilon { get; set; } = 0.0001f;
+        public string DataFormat { get; set; } = "NHWC";
+        public bool IsTraining { get; set; } = true;
+    }
+}
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
index 82085683..4e376d19 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
@@ -156,6 +156,35 @@ namespace Tensorflow.Operations
return op.output;
}
+        /// <summary>
+        /// Gradient for batch normalization.
+        /// </summary>
+        /// <remarks>
+        /// Creates a "FusedBatchNormGrad" op through _op_def_lib._apply_op_helper.
+        /// YBackprop, X, Scale, ReserveSpace1 and ReserveSpace2 are passed as
+        /// y_backprop, x, scale, reserve_space_1 and reserve_space_2.
+        /// Epsilon, DataFormat and IsTraining are passed as
+        /// epsilon, data_format and is_training.
+        /// ReserveSpace3 is not forwarded by this method.
+        /// </remarks>
+        /// <param name="params">Argument bundle; its Name property is used as the op name.</param>
+        /// <returns>The outputs of the created op.</returns>
+        public static Tensor[] fused_batch_norm_grad(FusedBatchNormParams @params)
+        {
+            var op_args = new
+            {
+                y_backprop = @params.YBackprop,
+                x = @params.X,
+                scale = @params.Scale,
+                reserve_space_1 = @params.ReserveSpace1,
+                reserve_space_2 = @params.ReserveSpace2,
+                epsilon = @params.Epsilon,
+                data_format = @params.DataFormat,
+                is_training = @params.IsTraining
+            };
+            return _op_def_lib._apply_op_helper("FusedBatchNormGrad", name: @params.Name, args: op_args).outputs;
+        }
+
public static Tensor[] fused_batch_norm(Tensor x,
Tensor scale,
Tensor offset,
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index caf5ac18..6118602c 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -218,6 +218,9 @@ namespace Tensorflow
return grouped_inputs.ToArray();
}
+        public T get_attr<T>(string name) // typed wrapper: casts the result of the untyped get_attr(name) to T
+            => (T)get_attr(name);
+
public object get_attr(string name)
{
AttrValue x = null;
diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
index 571457b9..54ccf590 100644
--- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
@@ -557,8 +557,31 @@ namespace Tensorflow
throw new NotImplementedException("ZerosLikeOutsideLoop");
return array_ops.zeros_like(val, optimize: false);
}
-
- throw new NotImplementedException("ZerosLikeOutsideLoop");
+ else
+ {
+ var op_ctxt = op._get_control_flow_context();
+ if(op_ctxt != null)
+ {
+ // We are in a cond context. Use a switch to create zeros only when needed.
+ var pred = op_ctxt.pred;
+ var branch = op_ctxt.branch;
+ var switch_val = @switch(op.inputs[0], pred)[1 - branch];
+ var pivot = array_ops.identity(switch_val);
+ if (val.dtype == dtypes.resource)
+ throw new NotImplementedException("");
+ var zeros_shape = array_ops.shape_internal(switch_val, optimize: false);
+ // Ensure ops created within array_ops.zeros are dominated by switch in
+ // cond context.
+ return tf_with(ops.control_dependencies(new[] { pivot }), delegate
+ {
+ return array_ops.zeros(zeros_shape, dtype: val.dtype);
+ });
+ }
+ else
+ {
+ return array_ops.zeros_like(val, optimize: false);
+ }
+ }
}
///
diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs
index fe0dc5e9..3827229d 100644
--- a/src/TensorFlowNET.Core/Tensors/dtypes.cs
+++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs
@@ -33,6 +33,7 @@ namespace Tensorflow
public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32?
public static TF_DataType float16 = TF_DataType.TF_HALF;
public static TF_DataType float64 = TF_DataType.TF_DOUBLE;
+ public static TF_DataType resource = TF_DataType.TF_RESOURCE;
///
///