@@ -36,7 +36,7 @@ namespace Tensorflow.Gradients
         /// </summary>
         /// <returns></returns>
         [RegisterGradient("Switch")]
-        public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads)
+        public Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
         {
             throw new NotImplementedException("_SwitchGrad");
             //graph = ops.get_default_graph()

@@ -108,7 +108,7 @@ namespace Tensorflow
                 {
                     // generate gradient subgraph for op.
                     var op = queue.Dequeue();
-                    if(tf.get_default_graph()._nodes_by_name.Count > 18505)
+                    if(tf.get_default_graph()._nodes_by_name.Count > 18577)
                     {
                     }

@@ -166,6 +166,94 @@ namespace Tensorflow.Gradients
             };
         }
 
+        [RegisterGradient("FusedBatchNorm")]
+        public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads)
+            => _BaseFusedBatchNormGrad(op, 0, grads);
+
+        /// <summary>
+        /// Return the gradients for the 3 inputs of BatchNorm.
+        /// </summary>
+        /// <param name="op"></param>
+        /// <param name="version"></param>
+        /// <param name="grads"></param>
+        /// <returns></returns>
+        public static Tensor[] _BaseFusedBatchNormGrad(Operation op, int version, Tensor[] grads)
+        {
+            var x = op.inputs[0];
+            var grad_y = grads[0];
+            var scale = op.inputs[1];
+            var epsilon = op.get_attr<float>("epsilon");
+            var data_format = op.get_attr<string>("data_format");
+            var is_training = op.get_attr<bool>("is_training");
+            Func<FusedBatchNormParams, Tensor[]> grad_fun = null;
+
+            switch (version)
+            {
+                case 2:
+                    throw new NotImplementedException("");
+                case 1:
+                    throw new NotImplementedException("");
+                default:
+                    grad_fun = gen_nn_ops.fused_batch_norm_grad;
+                    break;
+            }
+
+            if (is_training)
+            {
+                return grad_fun(new FusedBatchNormParams
+                {
+                    YBackprop = grad_y,
+                    X = x,
+                    Scale = scale,
+                    ReserveSpace1 = op.outputs[3],
+                    ReserveSpace2 = op.outputs[4],
+                    ReserveSpace3 = version == 2 ? op.outputs[5] : null,
+                    Epsilon = epsilon,
+                    DataFormat = data_format,
+                    IsTraining = is_training
+                });
+            }
+            else
+            {
+                var pop_mean = op.inputs[3];
+                var pop_var = op.inputs[4];
+                if (data_format == "NCHW")
+                    throw new NotImplementedException("");
+
+                var results = grad_fun(new FusedBatchNormParams
+                {
+                    YBackprop = grad_y,
+                    X = x,
+                    Scale = scale,
+                    ReserveSpace1 = op.outputs[3],
+                    ReserveSpace2 = op.outputs[4],
+                    ReserveSpace3 = version == 2 ? op.outputs[5] : null,
+                    Epsilon = epsilon,
+                    DataFormat = data_format,
+                    IsTraining = is_training
+                });
+
+                var (dx, dscale, doffset) = (results[0], results[1], results[2]);
+                if (data_format == "NCHW")
+                    throw new NotImplementedException("");
+
+                return new Tensor[]
+                {
+                    dx,
+                    dscale,
+                    doffset,
+                    null,
+                    null
+                };
+            }
+        }
+
+        [RegisterGradient("BatchNormWithGlobalNormalization")]
+        public static Tensor _BatchNormWithGlobalNormalizationGrad(Operation op, Tensor[] grads)
+        {
+            throw new NotImplementedException("BatchNormWithGlobalNormalization");
+        }
+
         private static bool IsZero(Tensor g)
         {
             if (new string[] { "ZerosLike", "Zeros" }.Contains(g.op.type))
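
A quick sketch of how the new gradient is consumed (the call site and variable names below are illustrative, not part of the change): the RegisterGradient attribute ties the "FusedBatchNorm" key to _FusedBatchNormGrad, and the returned array lines up one slot per input of the forward op.

// Illustrative only: fused_bn_op and grad_y stand in for the forward op and
// the incoming gradient of its first output.
Tensor[] in_grads = _FusedBatchNormGrad(fused_bn_op, new[] { grad_y });
var dx      = in_grads[0]; // gradient w.r.t. x
var dscale  = in_grads[1]; // gradient w.r.t. scale
var doffset = in_grads[2]; // gradient w.r.t. offset
// in_grads[3] and in_grads[4] map to the mean/variance inputs and carry no
// usable gradient (reserve outputs in training mode, null in inference mode).
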
@@ -27,20 +27,6 @@ namespace Tensorflow.Operations
     /// </summary>
     public class CondContext : ControlFlowContext, IProtoBuf<CondContextDef, CondContext>
     {
-        /// <summary>
-        /// The boolean tensor for the cond predicate
-        /// </summary>
-        private Tensor _pred;
-
-        public Tensor pred => _pred;
-
-        /// <summary>
-        /// 0 or 1 representing this branch
-        /// </summary>
-        private int _branch;
-
         private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>();
 
         /// <summary>

@@ -45,10 +45,19 @@ namespace Tensorflow.Operations
         /// The predicate tensor in this branch
         /// </summary>
         protected Tensor _pivot;
-        public Tensor pivot
-        {
-            get => _pivot;
-        }
+        public Tensor pivot => _pivot;
+
+        /// <summary>
+        /// The boolean tensor for the cond predicate
+        /// </summary>
+        protected Tensor _pred;
+        public Tensor pred => _pred;
+
+        /// <summary>
+        /// 0 or 1 representing this branch
+        /// </summary>
+        protected int _branch;
+        public int branch => _branch;
 
         protected Stack<ControlFlowContext> _context_stack;
         protected ControlFlowContext _outer_context;
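
The newly exposed pred and branch accessors are what the cond-aware zeros logic further down relies on; a minimal consumer looks roughly like this (op here is any operation sitting inside a cond branch):

var op_ctxt = op._get_control_flow_context();
if (op_ctxt != null)
{
    Tensor pred = op_ctxt.pred;  // boolean predicate of the enclosing cond
    int branch = op_ctxt.branch; // 0 or 1, identifying this branch
    // ... build graph pieces gated on the predicate/branch ...
}
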
@@ -0,0 +1,27 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations
+{
+    public class FusedBatchNormParams
+    {
+        public string Name { get; set; }
+        public Tensor YBackprop { get; set; }
+        public Tensor X { get; set; }
+        public Tensor Scale { get; set; }
+        public Tensor ReserveSpace1 { get; set; }
+        public Tensor ReserveSpace2 { get; set; }
+        public Tensor ReserveSpace3 { get; set; }
+        public float Epsilon { get; set; }
+        public string DataFormat { get; set; }
+        public bool IsTraining { get; set; }
+
+        public FusedBatchNormParams()
+        {
+            Epsilon = 0.0001f;
+            DataFormat = "NHWC";
+            IsTraining = true;
+        }
+    }
+}
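
The parameter object carries the common defaults from its constructor, so a caller only sets what differs; a minimal sketch (the tensor variables are assumed to come from the caller's graph):

var p = new FusedBatchNormParams
{
    YBackprop = grad_y,             // upstream gradient of y
    X = x,                          // forward input
    Scale = scale,                  // gamma
    ReserveSpace1 = reserve_space_1, // first reserve output of the forward op
    ReserveSpace2 = reserve_space_2  // second reserve output of the forward op
};
// Epsilon = 0.0001f, DataFormat = "NHWC", IsTraining = true are already set.
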
@@ -156,6 +156,35 @@ namespace Tensorflow.Operations
             return op.output;
         }
 
+        /// <summary>
+        /// Gradient for batch normalization.
+        /// </summary>
+        /// <param name="y_backprop"></param>
+        /// <param name="x"></param>
+        /// <param name="scale"></param>
+        /// <param name="reserve_space_1"></param>
+        /// <param name="reserve_space_2"></param>
+        /// <param name="epsilon"></param>
+        /// <param name="data_format"></param>
+        /// <param name="is_training"></param>
+        /// <param name="name"></param>
+        /// <returns></returns>
+        public static Tensor[] fused_batch_norm_grad(FusedBatchNormParams @params)
+        {
+            var op = _op_def_lib._apply_op_helper("FusedBatchNormGrad", name: @params.Name, args: new
+            {
+                y_backprop = @params.YBackprop,
+                x = @params.X,
+                scale = @params.Scale,
+                reserve_space_1 = @params.ReserveSpace1,
+                reserve_space_2 = @params.ReserveSpace2,
+                epsilon = @params.Epsilon,
+                data_format = @params.DataFormat,
+                is_training = @params.IsTraining
+            });
+
+            return op.outputs;
+        }
+
         public static Tensor[] fused_batch_norm(Tensor x,
             Tensor scale,
             Tensor offset,
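
Continuing the sketch above (with the same hypothetical p), the raw op returns its outputs in kernel order; the gradient code in Tensorflow.Gradients only reads the first three:

Tensor[] outs = gen_nn_ops.fused_batch_norm_grad(p);
var x_backprop = outs[0];      // gradient w.r.t. x
var scale_backprop = outs[1];  // gradient w.r.t. scale
var offset_backprop = outs[2]; // gradient w.r.t. offset
// any remaining entries are reserve-space outputs of the kernel.
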
@@ -218,6 +218,9 @@ namespace Tensorflow
             return grouped_inputs.ToArray();
         }
 
+        public T get_attr<T>(string name)
+            => (T)get_attr(name);
+
         public object get_attr(string name)
         {
             AttrValue x = null;
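
The generic overload is just a cast over the existing boxed accessor; it is what lets the gradient code above read op attributes tersely:

var epsilon = op.get_attr<float>("epsilon");
var data_format = op.get_attr<string>("data_format");
var is_training = op.get_attr<bool>("is_training");
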
@@ -557,8 +557,31 @@ namespace Tensorflow
                     throw new NotImplementedException("ZerosLikeOutsideLoop");
                 return array_ops.zeros_like(val, optimize: false);
             }
-            throw new NotImplementedException("ZerosLikeOutsideLoop");
+            else
+            {
+                var op_ctxt = op._get_control_flow_context();
+                if(op_ctxt != null)
+                {
+                    // We are in a cond context. Use a switch to create zeros only when needed.
+                    var pred = op_ctxt.pred;
+                    var branch = op_ctxt.branch;
+                    var switch_val = @switch(op.inputs[0], pred)[1 - branch];
+                    var pivot = array_ops.identity(switch_val);
+                    if (val.dtype == dtypes.resource)
+                        throw new NotImplementedException("");
+                    var zeros_shape = array_ops.shape_internal(switch_val, optimize: false);
+                    // Ensure ops created within array_ops.zeros are dominated by switch in
+                    // cond context.
+                    return tf_with(ops.control_dependencies(new[] { pivot }), delegate
+                    {
+                        return array_ops.zeros(zeros_shape, dtype: val.dtype);
+                    });
+                }
+                else
+                {
+                    return array_ops.zeros_like(val, optimize: false);
+                }
+            }
         }
 
         /// <summary>
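
The rationale, per the in-code comment, is to create zeros only when they are needed: instead of materializing them unconditionally, the zeros op is placed behind a control dependency on a Switch output driven by the cond predicate. Stripped of the surrounding method, the guard pattern is (val, pred and branch as in the change above):

var switch_val = @switch(op.inputs[0], pred)[1 - branch];
var pivot = array_ops.identity(switch_val);
var zeros_shape = array_ops.shape_internal(switch_val, optimize: false);
// the control dependency on pivot keeps the zeros inside the guarded path
var zeros = tf_with(ops.control_dependencies(new[] { pivot }),
    delegate { return array_ops.zeros(zeros_shape, dtype: val.dtype); });
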
@@ -33,6 +33,7 @@ namespace Tensorflow
         public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32?
         public static TF_DataType float16 = TF_DataType.TF_HALF;
         public static TF_DataType float64 = TF_DataType.TF_DOUBLE;
+        public static TF_DataType resource = TF_DataType.TF_RESOURCE;
 
         /// <summary>
         /// 