From 9e414f4aa6ede15c36929b2b2463f76591af90be Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Tue, 1 Oct 2019 17:00:14 -0500 Subject: [PATCH] add _FusedBatchNormGrad --- .../Gradients/control_flow_grad.cs | 2 +- .../Gradients/gradients_util.cs | 2 +- src/TensorFlowNET.Core/Gradients/nn_grad.cs | 88 +++++++++++++++++++ .../Operations/ControlFlows/CondContext.cs | 14 --- .../ControlFlows/ControlFlowContext.cs | 17 +++- .../Operations/NnOps/FusedBatchNormParams.cs | 27 ++++++ .../Operations/NnOps/gen_nn_ops.cs | 29 ++++++ .../Operations/Operation.cs | 3 + .../Operations/control_flow_ops.py.cs | 27 +++++- src/TensorFlowNET.Core/Tensors/dtypes.cs | 1 + 10 files changed, 188 insertions(+), 22 deletions(-) create mode 100644 src/TensorFlowNET.Core/Operations/NnOps/FusedBatchNormParams.cs diff --git a/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs b/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs index 670731e0..d8447163 100644 --- a/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs @@ -36,7 +36,7 @@ namespace Tensorflow.Gradients /// /// [RegisterGradient("Switch")] - public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads) + public Tensor[] _SwitchGrad(Operation op, Tensor[] grads) { throw new NotImplementedException("_SwitchGrad"); //graph = ops.get_default_graph() diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index 7252301a..a4508d3c 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -108,7 +108,7 @@ namespace Tensorflow { // generate gradient subgraph for op. 
var op = queue.Dequeue(); - if(tf.get_default_graph()._nodes_by_name.Count > 18505) + if(tf.get_default_graph()._nodes_by_name.Count > 18577) { } diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs index 7b5d2ea7..967b3c21 100644 --- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs @@ -166,6 +166,94 @@ namespace Tensorflow.Gradients }; } + [RegisterGradient("FusedBatchNorm")] + public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads) + => _BaseFusedBatchNormGrad(op, 0, grads); + + /// <summary> + /// Return the gradients for the 3 inputs of BatchNorm. + /// </summary> + /// <param name="op">The forward batch-norm operation; its inputs, outputs and the epsilon / data_format / is_training attributes are read.</param> + /// <param name="version">Gradient kernel version: 1 and 2 throw NotImplementedException, any other value uses gen_nn_ops.fused_batch_norm_grad.</param> + /// <param name="grads">Incoming gradients; only grads[0] (the gradient for y) is used.</param> + /// <returns>When is_training is true, the outputs of the gradient op; otherwise dx, dscale, doffset followed by two nulls.</returns> + public static Tensor[] _BaseFusedBatchNormGrad(Operation op, int version, Tensor[] grads) + { + var x = op.inputs[0]; + var grad_y = grads[0]; + var scale = op.inputs[1]; + var epsilon = op.get_attr<float>("epsilon"); + var data_format = op.get_attr<string>("data_format"); + var is_training = op.get_attr<bool>("is_training"); + Func<FusedBatchNormParams, Tensor[]> grad_fun = null; + + switch (version) + { + case 2: + throw new NotImplementedException(""); + case 1: + throw new NotImplementedException(""); + default: + grad_fun = gen_nn_ops.fused_batch_norm_grad; + break; + } + + if (is_training) + { + return grad_fun(new FusedBatchNormParams + { + YBackprop = grad_y, + X = x, + Scale = scale, + ReserveSpace1 = op.outputs[3], + ReserveSpace2 = op.outputs[4], + ReserveSpace3 = version == 2 ? op.outputs[5] : null, + Epsilon = epsilon, + DataFormat = data_format, + IsTraining = is_training + }); + } + else + { + var pop_mean = op.inputs[3]; + var pop_var = op.inputs[4]; + if (data_format == "NCHW") + throw new NotImplementedException(""); + + var results = grad_fun(new FusedBatchNormParams + { + YBackprop = grad_y, + X = x, + Scale = scale, + ReserveSpace1 = pop_mean, // not training: pass op.inputs[3] rather than the forward op's outputs[3] + ReserveSpace2 = pop_var, // not training: pass op.inputs[4] rather than the forward op's outputs[4] + ReserveSpace3 = version == 2 ? 
op.outputs[5] : null, + Epsilon = epsilon, + DataFormat = data_format, + IsTraining = is_training + }); + + var (dx, dscale, doffset) = (results[0], results[1], results[2]); + if (data_format == "NCHW") + throw new NotImplementedException(""); + + return new Tensor[] + { + dx, + dscale, + doffset, + null, + null + }; + } + } + + [RegisterGradient("BatchNormWithGlobalNormalization")] + public static Tensor _BatchNormWithGlobalNormalizationGrad(Operation op, Tensor[] grads) + { + throw new NotImplementedException("BatchNormWithGlobalNormalization"); + } + private static bool IsZero(Tensor g) { if (new string[] { "ZerosLike", "Zeros" }.Contains(g.op.type)) diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs index aa314efb..ce2295c8 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs @@ -27,20 +27,6 @@ namespace Tensorflow.Operations /// public class CondContext : ControlFlowContext, IProtoBuf { - - - /// - /// The boolean tensor for the cond predicate - /// - private Tensor _pred; - - public Tensor pred => _pred; - - /// - /// 0 or 1 representing this branch - /// - private int _branch; - private Dictionary _external_values = new Dictionary(); /// diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index 2a76c52c..c076cbc7 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -45,10 +45,19 @@ namespace Tensorflow.Operations /// The predicate tensor in this branch /// protected Tensor _pivot; - public Tensor pivot - { - get => _pivot; - } + public Tensor pivot => _pivot; + + /// + /// The boolean tensor for the cond predicate + /// + protected Tensor _pred; + public Tensor pred => _pred; + + 
/// + /// 0 or 1 representing this branch + /// + protected int _branch; + public int branch => _branch; protected Stack _context_stack; protected ControlFlowContext _outer_context; diff --git a/src/TensorFlowNET.Core/Operations/NnOps/FusedBatchNormParams.cs b/src/TensorFlowNET.Core/Operations/NnOps/FusedBatchNormParams.cs new file mode 100644 index 00000000..689fa5fe --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/FusedBatchNormParams.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations +{ + public class FusedBatchNormParams + { + public string Name { get; set; } + public Tensor YBackprop { get; set; } + public Tensor X { get; set; } + public Tensor Scale { get; set; } + public Tensor ReserveSpace1 { get; set; } + public Tensor ReserveSpace2 { get; set; } + public Tensor ReserveSpace3 { get; set; } + public float Epsilon { get; set; } + public string DataFormat { get; set; } + public bool IsTraining { get; set; } + + public FusedBatchNormParams() + { + Epsilon = 0.0001f; + DataFormat = "NHWC"; + IsTraining = true; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index 82085683..4e376d19 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -156,6 +156,35 @@ namespace Tensorflow.Operations return op.output; } + /// + /// Gradient for batch normalization. 
+ /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] fused_batch_norm_grad(FusedBatchNormParams @params) + { + var op = _op_def_lib._apply_op_helper("FusedBatchNormGrad", name: @params.Name, args: new + { + y_backprop = @params.YBackprop, + x = @params.X, + scale = @params.Scale, + reserve_space_1 = @params.ReserveSpace1, + reserve_space_2 = @params.ReserveSpace2, + epsilon = @params.Epsilon, + data_format = @params.DataFormat, + is_training = @params.IsTraining + }); + return op.outputs; + } + public static Tensor[] fused_batch_norm(Tensor x, Tensor scale, Tensor offset, diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index caf5ac18..6118602c 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -218,6 +218,9 @@ namespace Tensorflow return grouped_inputs.ToArray(); } + public T get_attr(string name) + => (T)get_attr(name); + public object get_attr(string name) { AttrValue x = null; diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index 571457b9..54ccf590 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -557,8 +557,31 @@ namespace Tensorflow throw new NotImplementedException("ZerosLikeOutsideLoop"); return array_ops.zeros_like(val, optimize: false); } - - throw new NotImplementedException("ZerosLikeOutsideLoop"); + else + { + var op_ctxt = op._get_control_flow_context(); + if(op_ctxt != null) + { + // We are in a cond context. Use a switch to create zeros only when needed. 
+ var pred = op_ctxt.pred; + var branch = op_ctxt.branch; + var switch_val = @switch(op.inputs[0], pred)[1 - branch]; + var pivot = array_ops.identity(switch_val); + if (val.dtype == dtypes.resource) + throw new NotImplementedException(""); + var zeros_shape = array_ops.shape_internal(switch_val, optimize: false); + // Ensure ops created within array_ops.zeros are dominated by switch in + // cond context. + return tf_with(ops.control_dependencies(new[] { pivot }), delegate + { + return array_ops.zeros(zeros_shape, dtype: val.dtype); + }); + } + else + { + return array_ops.zeros_like(val, optimize: false); + } + } } /// diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index fe0dc5e9..3827229d 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -33,6 +33,7 @@ namespace Tensorflow public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? public static TF_DataType float16 = TF_DataType.TF_HALF; public static TF_DataType float64 = TF_DataType.TF_DOUBLE; + public static TF_DataType resource = TF_DataType.TF_RESOURCE; /// ///