Browse Source

add _FusedBatchNormGrad

tags/v0.12
Oceania2018 6 years ago
parent
commit
9e414f4aa6
10 changed files with 188 additions and 22 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Gradients/control_flow_grad.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Gradients/gradients_util.cs
  3. +88
    -0
      src/TensorFlowNET.Core/Gradients/nn_grad.cs
  4. +0
    -14
      src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
  5. +13
    -4
      src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
  6. +27
    -0
      src/TensorFlowNET.Core/Operations/NnOps/FusedBatchNormParams.cs
  7. +29
    -0
      src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
  8. +3
    -0
      src/TensorFlowNET.Core/Operations/Operation.cs
  9. +25
    -2
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  10. +1
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs

+ 1
- 1
src/TensorFlowNET.Core/Gradients/control_flow_grad.cs View File

@@ -36,7 +36,7 @@ namespace Tensorflow.Gradients
/// </summary>
/// <returns></returns>
[RegisterGradient("Switch")]
public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads)
public Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
{
throw new NotImplementedException("_SwitchGrad");
//graph = ops.get_default_graph()


+ 1
- 1
src/TensorFlowNET.Core/Gradients/gradients_util.cs View File

@@ -108,7 +108,7 @@ namespace Tensorflow
{
// generate gradient subgraph for op.
var op = queue.Dequeue();
if(tf.get_default_graph()._nodes_by_name.Count > 18505)
if(tf.get_default_graph()._nodes_by_name.Count > 18577)
{

}


+ 88
- 0
src/TensorFlowNET.Core/Gradients/nn_grad.cs View File

@@ -166,6 +166,94 @@ namespace Tensorflow.Gradients
};
}

/// <summary>
/// Registered gradient for the v1 "FusedBatchNorm" op.
/// </summary>
/// <param name="op">The forward FusedBatchNorm operation.</param>
/// <param name="grads">Incoming gradients; grads[0] is dL/dy.</param>
/// <returns>Gradients for the inputs of FusedBatchNorm.</returns>
[RegisterGradient("FusedBatchNorm")]
public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads)
{
    // version 0 selects the v1 FusedBatchNormGrad kernel in the shared implementation.
    return _BaseFusedBatchNormGrad(op, 0, grads);
}

/// <summary>
/// Return the gradients for the 3 inputs of BatchNorm.
/// </summary>
/// <param name="op">The forward FusedBatchNorm* operation.</param>
/// <param name="version">0 => FusedBatchNormGrad; 1/2 => V2/V3 (not yet implemented).</param>
/// <param name="grads">Incoming gradients; grads[0] is dL/dy.</param>
/// <returns>
/// Gradients w.r.t. x, scale and offset; in inference mode two trailing nulls
/// stand in for the (non-differentiable) mean/variance inputs.
/// </returns>
public static Tensor[] _BaseFusedBatchNormGrad(Operation op, int version, Tensor[] grads)
{
    var x = op.inputs[0];
    var grad_y = grads[0];
    var scale = op.inputs[1];
    var epsilon = op.get_attr<float>("epsilon");
    var data_format = op.get_attr<string>("data_format");
    var is_training = op.get_attr<bool>("is_training");
    Func<FusedBatchNormParams, Tensor[]> grad_fun = null;

    switch (version)
    {
        case 2:
            throw new NotImplementedException("FusedBatchNormV3 gradient");
        case 1:
            throw new NotImplementedException("FusedBatchNormV2 gradient");
        default:
            grad_fun = gen_nn_ops.fused_batch_norm_grad;
            break;
    }

    if (is_training)
    {
        // Training mode: the grad kernel consumes the reserve spaces the
        // forward op produced (outputs[3]/[4], plus outputs[5] for V3).
        return grad_fun(new FusedBatchNormParams
        {
            YBackprop = grad_y,
            X = x,
            Scale = scale,
            ReserveSpace1 = op.outputs[3],
            ReserveSpace2 = op.outputs[4],
            ReserveSpace3 = version == 2 ? op.outputs[5] : null,
            Epsilon = epsilon,
            DataFormat = data_format,
            IsTraining = is_training
        });
    }
    else
    {
        // Inference mode: the grad kernel must be fed the population
        // statistics (inputs[3]/[4]) in the reserve-space slots, not the
        // training-time reserve outputs. (Previously op.outputs[3]/[4] were
        // passed here while pop_mean/pop_var went unused — a bug.)
        var pop_mean = op.inputs[3];
        var pop_var = op.inputs[4];
        if (data_format == "NCHW")
            throw new NotImplementedException("NCHW inference-mode FusedBatchNorm gradient");

        var results = grad_fun(new FusedBatchNormParams
        {
            YBackprop = grad_y,
            X = x,
            Scale = scale,
            ReserveSpace1 = pop_mean,
            ReserveSpace2 = pop_var,
            ReserveSpace3 = version == 2 ? op.outputs[5] : null,
            Epsilon = epsilon,
            DataFormat = data_format,
            IsTraining = is_training
        });

        var (dx, dscale, doffset) = (results[0], results[1], results[2]);
        if (data_format == "NCHW")
            throw new NotImplementedException("NCHW inference-mode FusedBatchNorm gradient");

        return new Tensor[]
        {
            dx,
            dscale,
            doffset,
            null,
            null
        };
    }
}

/// <summary>
/// Gradient for the deprecated BatchNormWithGlobalNormalization op.
/// Currently a stub: always throws <see cref="NotImplementedException"/>.
/// </summary>
// NOTE(review): returns Tensor, while the other registered gradients here
// (e.g. _FusedBatchNormGrad) return Tensor[] — confirm the gradient registry
// accepts this signature, or align it with the sibling methods.
[RegisterGradient("BatchNormWithGlobalNormalization")]
public static Tensor _BatchNormWithGlobalNormalizationGrad(Operation op, Tensor[] grads)
{
    throw new NotImplementedException("BatchNormWithGlobalNormalization");
}

private static bool IsZero(Tensor g)
{
if (new string[] { "ZerosLike", "Zeros" }.Contains(g.op.type))


+ 0
- 14
src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs View File

@@ -27,20 +27,6 @@ namespace Tensorflow.Operations
/// </summary>
public class CondContext : ControlFlowContext, IProtoBuf<CondContextDef, CondContext>
{


/// <summary>
/// The boolean tensor for the cond predicate
/// </summary>
private Tensor _pred;

public Tensor pred => _pred;

/// <summary>
/// 0 or 1 representing this branch
/// </summary>
private int _branch;

private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>();

/// <summary>


+ 13
- 4
src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs View File

@@ -45,10 +45,19 @@ namespace Tensorflow.Operations
/// The predicate tensor in this branch
/// </summary>
protected Tensor _pivot;
public Tensor pivot
{
get => _pivot;
}
public Tensor pivot => _pivot;

/// <summary>
/// The boolean tensor for the cond predicate
/// </summary>
protected Tensor _pred;
public Tensor pred => _pred;

/// <summary>
/// 0 or 1 representing this branch
/// </summary>
protected int _branch;
public int branch => _branch;

protected Stack<ControlFlowContext> _context_stack;
protected ControlFlowContext _outer_context;


+ 27
- 0
src/TensorFlowNET.Core/Operations/NnOps/FusedBatchNormParams.cs View File

@@ -0,0 +1,27 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations
{
/// <summary>
/// Parameter bag for <c>gen_nn_ops.fused_batch_norm_grad</c> (FusedBatchNormGrad op).
/// </summary>
public class FusedBatchNormParams
{
    // Optional name for the created op.
    public string Name { get; set; }
    // Gradient of the loss w.r.t. the forward op's output y.
    public Tensor YBackprop { get; set; }
    public Tensor X { get; set; }
    public Tensor Scale { get; set; }
    // Reserve spaces from the forward op in training mode; population
    // statistics in inference mode.
    public Tensor ReserveSpace1 { get; set; }
    public Tensor ReserveSpace2 { get; set; }
    // Only used by the V3 variant of the op; null otherwise.
    public Tensor ReserveSpace3 { get; set; }
    public float Epsilon { get; set; } = 0.0001f;
    public string DataFormat { get; set; } = "NHWC";
    public bool IsTraining { get; set; } = true;
}
}

+ 29
- 0
src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs View File

@@ -156,6 +156,35 @@ namespace Tensorflow.Operations
return op.output;
}

/// <summary>
/// Gradient for batch normalization: emits a "FusedBatchNormGrad" op.
/// </summary>
/// <param name="params">
/// Inputs and attributes for the op (y_backprop, x, scale, reserve spaces,
/// epsilon, data_format, is_training).
/// </param>
/// <returns>The outputs of the created FusedBatchNormGrad operation.</returns>
public static Tensor[] fused_batch_norm_grad(FusedBatchNormParams @params)
{
    var p = @params;
    // Anonymous-object property names must match the op-def argument names exactly.
    var op = _op_def_lib._apply_op_helper("FusedBatchNormGrad", name: p.Name, args: new
    {
        y_backprop = p.YBackprop,
        x = p.X,
        scale = p.Scale,
        reserve_space_1 = p.ReserveSpace1,
        reserve_space_2 = p.ReserveSpace2,
        epsilon = p.Epsilon,
        data_format = p.DataFormat,
        is_training = p.IsTraining
    });

    return op.outputs;
}

public static Tensor[] fused_batch_norm(Tensor x,
Tensor scale,
Tensor offset,


+ 3
- 0
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -218,6 +218,9 @@ namespace Tensorflow
return grouped_inputs.ToArray();
}

/// <summary>
/// Typed convenience wrapper over <c>get_attr(string)</c>: fetches the named
/// attribute and casts the boxed value to <typeparamref name="T"/>.
/// Throws <see cref="InvalidCastException"/> if the stored value is not a T.
/// </summary>
public T get_attr<T>(string name)
{
    var value = get_attr(name);
    return (T)value;
}

public object get_attr(string name)
{
AttrValue x = null;


+ 25
- 2
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -557,8 +557,31 @@ namespace Tensorflow
throw new NotImplementedException("ZerosLikeOutsideLoop");
return array_ops.zeros_like(val, optimize: false);
}

throw new NotImplementedException("ZerosLikeOutsideLoop");
else
{
var op_ctxt = op._get_control_flow_context();
if(op_ctxt != null)
{
// We are in a cond context. Use a switch to create zeros only when needed.
var pred = op_ctxt.pred;
var branch = op_ctxt.branch;
var switch_val = @switch(op.inputs[0], pred)[1 - branch];
var pivot = array_ops.identity(switch_val);
if (val.dtype == dtypes.resource)
throw new NotImplementedException("");
var zeros_shape = array_ops.shape_internal(switch_val, optimize: false);
// Ensure ops created within array_ops.zeros are dominated by switch in
// cond context.
return tf_with(ops.control_dependencies(new[] { pivot }), delegate
{
return array_ops.zeros(zeros_shape, dtype: val.dtype);
});
}
else
{
return array_ops.zeros_like(val, optimize: false);
}
}
}

/// <summary>


+ 1
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -33,6 +33,7 @@ namespace Tensorflow
public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32?
public static TF_DataType float16 = TF_DataType.TF_HALF;
public static TF_DataType float64 = TF_DataType.TF_DOUBLE;
public static TF_DataType resource = TF_DataType.TF_RESOURCE;

/// <summary>
///


Loading…
Cancel
Save