@@ -36,7 +36,7 @@ namespace Tensorflow.Gradients
        /// </summary>
        /// <returns></returns>
        [RegisterGradient("Switch")]
        public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads)
        public Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
        {
            throw new NotImplementedException("_SwitchGrad");
            //graph = ops.get_default_graph()
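The signature change above (Tensor op to Operation op) brings _SwitchGrad in line with the contract the other registered gradients in this diff follow: the function receives the forward Operation plus the incoming gradients and returns one gradient per op input. A minimal sketch of that shape, inside a gradient class, using a placeholder op name that is not part of this change set:

// Sketch only (not part of this change set): the general shape of a registered
// gradient. `op` is the forward node, `grads` carries one incoming gradient per
// op output, and the result must carry one gradient (or null) per op input.
[RegisterGradient("SomePlaceholderOp")]   // placeholder op name, for illustration
public static Tensor[] _SomePlaceholderOpGrad(Operation op, Tensor[] grads)
{
    var dy = grads[0];                    // gradient w.r.t. the single output
    return new Tensor[] { dy };           // gradient w.r.t. the single input
}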
@@ -108,7 +108,7 @@ namespace Tensorflow
                {
                    // generate gradient subgraph for op.
                    var op = queue.Dequeue();
                    if (tf.get_default_graph()._nodes_by_name.Count > 18505)
                    if (tf.get_default_graph()._nodes_by_name.Count > 18577)
                    {
                    }
@@ -166,6 +166,94 @@ namespace Tensorflow.Gradients
            };
        }

        [RegisterGradient("FusedBatchNorm")]
        public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads)
            => _BaseFusedBatchNormGrad(op, 0, grads);

        /// <summary>
        /// Return the gradients for the 3 inputs of BatchNorm.
        /// </summary>
        /// <param name="op"></param>
        /// <param name="version"></param>
        /// <param name="grads"></param>
        /// <returns></returns>
        public static Tensor[] _BaseFusedBatchNormGrad(Operation op, int version, Tensor[] grads)
        {
            var x = op.inputs[0];
            var grad_y = grads[0];
            var scale = op.inputs[1];
            var epsilon = op.get_attr<float>("epsilon");
            var data_format = op.get_attr<string>("data_format");
            var is_training = op.get_attr<bool>("is_training");
            Func<FusedBatchNormParams, Tensor[]> grad_fun = null;

            switch (version)
            {
                case 2:
                    throw new NotImplementedException("");
                case 1:
                    throw new NotImplementedException("");
                default:
                    grad_fun = gen_nn_ops.fused_batch_norm_grad;
                    break;
            }

            if (is_training)
            {
                return grad_fun(new FusedBatchNormParams
                {
                    YBackprop = grad_y,
                    X = x,
                    Scale = scale,
                    ReserveSpace1 = op.outputs[3],
                    ReserveSpace2 = op.outputs[4],
                    ReserveSpace3 = version == 2 ? op.outputs[5] : null,
                    Epsilon = epsilon,
                    DataFormat = data_format,
                    IsTraining = is_training
                });
            }
            else
            {
                var pop_mean = op.inputs[3];
                var pop_var = op.inputs[4];
                if (data_format == "NCHW")
                    throw new NotImplementedException("");
                var results = grad_fun(new FusedBatchNormParams
                {
                    YBackprop = grad_y,
                    X = x,
                    Scale = scale,
                    ReserveSpace1 = op.outputs[3],
                    ReserveSpace2 = op.outputs[4],
                    ReserveSpace3 = version == 2 ? op.outputs[5] : null,
                    Epsilon = epsilon,
                    DataFormat = data_format,
                    IsTraining = is_training
                });
                var (dx, dscale, doffset) = (results[0], results[1], results[2]);
                if (data_format == "NCHW")
                    throw new NotImplementedException("");
                return new Tensor[]
                {
                    dx,
                    dscale,
                    doffset,
                    null,
                    null
                };
            }
        }
| [RegisterGradient("BatchNormWithGlobalNormalization")] | |||
| public static Tensor _BatchNormWithGlobalNormalizationGrad(Operation op, Tensor[] grads) | |||
| { | |||
| throw new NotImplementedException("BatchNormWithGlobalNormalization"); | |||
| } | |||
        private static bool IsZero(Tensor g)
        {
            if (new string[] { "ZerosLike", "Zeros" }.Contains(g.op.type))
@@ -27,20 +27,6 @@ namespace Tensorflow.Operations
    /// </summary>
    public class CondContext : ControlFlowContext, IProtoBuf<CondContextDef, CondContext>
    {
        /// <summary>
        /// The boolean tensor for the cond predicate
        /// </summary>
        private Tensor _pred;

        public Tensor pred => _pred;

        /// <summary>
        /// 0 or 1 representing this branch
        /// </summary>
        private int _branch;

        private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>();

        /// <summary>
@@ -45,10 +45,19 @@ namespace Tensorflow.Operations
        /// <summary>
        /// The predicate tensor in this branch
        /// </summary>
        protected Tensor _pivot;
        public Tensor pivot
        {
            get => _pivot;
        }
        public Tensor pivot => _pivot;

        /// <summary>
        /// The boolean tensor for the cond predicate
        /// </summary>
        protected Tensor _pred;
        public Tensor pred => _pred;

        /// <summary>
        /// 0 or 1 representing this branch
        /// </summary>
        protected int _branch;
        public int branch => _branch;

        protected Stack<ControlFlowContext> _context_stack;
        protected ControlFlowContext _outer_context;
@@ -0,0 +1,27 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations
{
    public class FusedBatchNormParams
    {
        public string Name { get; set; }
        public Tensor YBackprop { get; set; }
        public Tensor X { get; set; }
        public Tensor Scale { get; set; }
        public Tensor ReserveSpace1 { get; set; }
        public Tensor ReserveSpace2 { get; set; }
        public Tensor ReserveSpace3 { get; set; }
        public float Epsilon { get; set; }
        public string DataFormat { get; set; }
        public bool IsTraining { get; set; }

        public FusedBatchNormParams()
        {
            Epsilon = 0.0001f;
            DataFormat = "NHWC";
            IsTraining = true;
        }
    }
}
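For orientation, a hedged usage sketch of the new parameter object together with gen_nn_ops.fused_batch_norm_grad (defined further down in this diff). It mirrors the training branch of _BaseFusedBatchNormGrad above; grad_y, x, scale and op are assumed to be in scope, and the comments describe the conventional FusedBatchNorm output layout rather than anything introduced by this change:

// Illustrative only: invoking the FusedBatchNormGrad kernel by hand.
// ReserveSpace1/2 are the saved statistics the forward FusedBatchNorm op
// placed in outputs[3] and outputs[4] for reuse in the backward pass.
var results = gen_nn_ops.fused_batch_norm_grad(new FusedBatchNormParams
{
    YBackprop = grad_y,             // gradient flowing in from above
    X = x,                          // forward input
    Scale = scale,                  // gamma
    ReserveSpace1 = op.outputs[3],
    ReserveSpace2 = op.outputs[4],
    Epsilon = 0.001f,               // overrides the 0.0001f constructor default
    DataFormat = "NHWC",            // also the constructor default
    IsTraining = true
});
// Result ordering as used by _BaseFusedBatchNormGrad:
var (dx, dscale, doffset) = (results[0], results[1], results[2]);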
@@ -156,6 +156,35 @@ namespace Tensorflow.Operations
            return op.output;
        }

        /// <summary>
        /// Gradient for batch normalization.
        /// </summary>
| /// <param name="y_backprop"></param> | |||
| /// <param name="x"></param> | |||
| /// <param name="scale"></param> | |||
| /// <param name="reserve_space_1"></param> | |||
| /// <param name="reserve_space_2"></param> | |||
| /// <param name="epsilon"></param> | |||
| /// <param name="data_format"></param> | |||
| /// <param name="is_training"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
        public static Tensor[] fused_batch_norm_grad(FusedBatchNormParams @params)
        {
            var op = _op_def_lib._apply_op_helper("FusedBatchNormGrad", name: @params.Name, args: new
            {
                y_backprop = @params.YBackprop,
                x = @params.X,
                scale = @params.Scale,
                reserve_space_1 = @params.ReserveSpace1,
                reserve_space_2 = @params.ReserveSpace2,
                epsilon = @params.Epsilon,
                data_format = @params.DataFormat,
                is_training = @params.IsTraining
            });

            return op.outputs;
        }

        public static Tensor[] fused_batch_norm(Tensor x,
            Tensor scale,
            Tensor offset,
@@ -218,6 +218,9 @@ namespace Tensorflow
            return grouped_inputs.ToArray();
        }

        public T get_attr<T>(string name)
            => (T)get_attr(name);

        public object get_attr(string name)
        {
            AttrValue x = null;
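The new generic overload simply casts the boxed object returned by the existing get_attr(string), so the requested T has to match the attribute's stored type (it is an unboxing cast; asking for double where a float was stored would fail at run time). Typical usage, as in _BaseFusedBatchNormGrad earlier in this diff:

// Typed attribute reads on a FusedBatchNorm op via the new overload.
var epsilon = op.get_attr<float>("epsilon");
var data_format = op.get_attr<string>("data_format");
var is_training = op.get_attr<bool>("is_training");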
@@ -557,8 +557,31 @@ namespace Tensorflow
                    throw new NotImplementedException("ZerosLikeOutsideLoop");
                return array_ops.zeros_like(val, optimize: false);
            }
            throw new NotImplementedException("ZerosLikeOutsideLoop");
            else
            {
                var op_ctxt = op._get_control_flow_context();
                if (op_ctxt != null)
                {
                    // We are in a cond context. Use a switch to create zeros only when needed.
                    var pred = op_ctxt.pred;
                    var branch = op_ctxt.branch;
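                    // switch(data, pred) yields (output_false, output_true); taking index
                    // 1 - branch picks the output that is live on the opposite branch, and
                    // the control dependency on `pivot` below keeps the zeros from being
                    // computed unless that branch actually executes.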
                    var switch_val = @switch(op.inputs[0], pred)[1 - branch];
                    var pivot = array_ops.identity(switch_val);
                    if (val.dtype == dtypes.resource)
                        throw new NotImplementedException("");
                    var zeros_shape = array_ops.shape_internal(switch_val, optimize: false);
                    // Ensure ops created within array_ops.zeros are dominated by switch in
                    // cond context.
                    return tf_with(ops.control_dependencies(new[] { pivot }), delegate
                    {
                        return array_ops.zeros(zeros_shape, dtype: val.dtype);
                    });
                }
                else
                {
                    return array_ops.zeros_like(val, optimize: false);
                }
            }
        }

        /// <summary>
@@ -33,6 +33,7 @@ namespace Tensorflow
        public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32?
        public static TF_DataType float16 = TF_DataType.TF_HALF;
        public static TF_DataType float64 = TF_DataType.TF_DOUBLE;
        public static TF_DataType resource = TF_DataType.TF_RESOURCE;

        /// <summary>
        ///