From 116f21728c89eb48e49a613836c0b4a52db4b96f Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Fri, 30 Aug 2019 05:31:25 -0500 Subject: [PATCH] fix gen_nn_ops.fused_batch_norm return values. --- .../Framework/smart_module.cs | 23 +++++++++++++------ .../Operations/NnOps/gen_nn_ops.cs | 2 +- .../Operations/nn_impl.py.cs | 21 ++++++++++++++++- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/src/TensorFlowNET.Core/Framework/smart_module.cs b/src/TensorFlowNET.Core/Framework/smart_module.cs index 908acb75..67102cab 100644 --- a/src/TensorFlowNET.Core/Framework/smart_module.cs +++ b/src/TensorFlowNET.Core/Framework/smart_module.cs @@ -20,15 +20,24 @@ namespace Tensorflow.Framework { public class smart_module { - public static Tensor[] smart_cond(Tensor pred, - Func true_fn = null, - Func false_fn = null, + public static Tensor[] smart_cond(Tensor pred, + Func true_fn = null, + Func false_fn = null, string name = null) { - return control_flow_ops.cond(pred, - true_fn: true_fn, - false_fn: false_fn, - name: name); + var pred_value = smart_constant_value(pred); + if (pred_value.HasValue) + { + if (pred_value.Value) + return true_fn() as Tensor[]; + else + return false_fn() as Tensor[]; + } + else + return control_flow_ops.cond(pred, + true_fn: true_fn, + false_fn: false_fn, + name: name); } public static bool? smart_constant_value(Tensor pred) diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index 49d504ab..c79f89b2 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -156,7 +156,7 @@ namespace Tensorflow.Operations return op.output; } - public static Tensor[] _fused_batch_norm(Tensor x, + public static Tensor[] fused_batch_norm(Tensor x, Tensor scale, Tensor offset, Tensor mean, diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs index bd70c10a..368214c1 100644 --- a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs +++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs @@ -83,6 +83,19 @@ namespace Tensorflow }); } + /// + /// Batch normalization. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// public static Tensor[] fused_batch_norm(Tensor x, RefVariable scale, RefVariable offset, @@ -103,7 +116,7 @@ namespace Tensorflow var min_epsilon = 1.001e-5f; epsilon = epsilon > min_epsilon ? epsilon : min_epsilon; - return gen_nn_ops._fused_batch_norm(x, + var results = gen_nn_ops.fused_batch_norm(x, scale_tensor, offset_tensor, mean, @@ -112,6 +125,12 @@ namespace Tensorflow data_format, is_training, name); + + var y = results[0]; + var batch_mean = results[1]; + var batch_var = results[2]; + + return new[] { y, batch_mean, batch_var }; } ///