Browse Source

fix gen_nn_ops.fused_batch_norm return values.

tags/v0.12
Oceania2018 6 years ago
parent
commit
116f21728c
3 changed files with 37 additions and 9 deletions
  1. +16
    -7
      src/TensorFlowNET.Core/Framework/smart_module.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
  3. +20
    -1
      src/TensorFlowNET.Core/Operations/nn_impl.py.cs

+ 16
- 7
src/TensorFlowNET.Core/Framework/smart_module.cs View File

@@ -20,15 +20,24 @@ namespace Tensorflow.Framework
{
public class smart_module
{
public static Tensor[] smart_cond<T>(Tensor pred,
Func<T[]> true_fn = null,
Func<T[]> false_fn = null,
public static Tensor[] smart_cond<T>(Tensor pred,
Func<T[]> true_fn = null,
Func<T[]> false_fn = null,
string name = null)
{
return control_flow_ops.cond(pred,
true_fn: true_fn,
false_fn: false_fn,
name: name);
var pred_value = smart_constant_value(pred);
if (pred_value.HasValue)
{
if (pred_value.Value)
return true_fn() as Tensor[];
else
return false_fn() as Tensor[];
}
else
return control_flow_ops.cond(pred,
true_fn: true_fn,
false_fn: false_fn,
name: name);
}

public static bool? smart_constant_value(Tensor pred)


+ 1
- 1
src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs View File

@@ -156,7 +156,7 @@ namespace Tensorflow.Operations
return op.output;
}

public static Tensor[] _fused_batch_norm(Tensor x,
public static Tensor[] fused_batch_norm(Tensor x,
Tensor scale,
Tensor offset,
Tensor mean,


+ 20
- 1
src/TensorFlowNET.Core/Operations/nn_impl.py.cs View File

@@ -83,6 +83,19 @@ namespace Tensorflow
});
}

/// <summary>
/// Batch normalization.
/// </summary>
/// <param name="x"></param>
/// <param name="scale"></param>
/// <param name="offset"></param>
/// <param name="mean"></param>
/// <param name="variance"></param>
/// <param name="epsilon"></param>
/// <param name="data_format"></param>
/// <param name="is_training"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor[] fused_batch_norm(Tensor x,
RefVariable scale,
RefVariable offset,
@@ -103,7 +116,7 @@ namespace Tensorflow
var min_epsilon = 1.001e-5f;
epsilon = epsilon > min_epsilon ? epsilon : min_epsilon;

return gen_nn_ops._fused_batch_norm(x,
var results = gen_nn_ops.fused_batch_norm(x,
scale_tensor,
offset_tensor,
mean,
@@ -112,6 +125,12 @@ namespace Tensorflow
data_format,
is_training,
name);

var y = results[0];
var batch_mean = results[1];
var batch_var = results[2];

return new[] { y, batch_mean, batch_var };
}

/// <summary>


Loading…
Cancel
Save