From ce0d722355670b4f03cba45925110dce64be821a Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 18 Oct 2020 20:24:56 -0500 Subject: [PATCH] Fix fused_batch_norm_v3 for eager mode. --- .../Gradients/gradient_exclustions.cs | 1 + src/TensorFlowNET.Core/Keras/Engine/Flatten.cs | 2 +- .../Keras/Engine/Functional.cs | 3 ++- .../Keras/Engine/Layer.Apply.cs | 2 +- .../Engine/Layer.FunctionalConstructionCall.cs | 2 +- src/TensorFlowNET.Core/Keras/Engine/Layer.cs | 2 +- .../Keras/Engine/TensorFlowOpLayer.cs | 4 ++-- .../Keras/Layers/BatchNormalization.cs | 2 +- .../Keras/Layers/Convolutional.cs | 2 +- src/TensorFlowNET.Core/Keras/Layers/Dense.cs | 2 +- src/TensorFlowNET.Core/Keras/Layers/Dropout.cs | 2 +- .../Keras/Layers/Embedding.cs | 2 +- src/TensorFlowNET.Core/Keras/Layers/LSTM.cs | 4 ++-- .../Keras/Layers/Pooling2D.cs | 2 +- .../Keras/Layers/Rescaling.cs | 2 +- .../Keras/Layers/ZeroPadding2D.cs | 2 +- .../Operations/NnOps/BasicLSTMCell.cs | 2 +- .../Operations/NnOps/BasicRNNCell.cs | 2 +- .../Operations/NnOps/gen_nn_ops.cs | 17 +++++++++++++++++ .../Tensorflow.Binding.csproj | 2 +- src/TensorFlowNET.Core/Tensors/Tensor.Value.cs | 3 +++ 21 files changed, 42 insertions(+), 20 deletions(-) diff --git a/src/TensorFlowNET.Core/Gradients/gradient_exclustions.cs b/src/TensorFlowNET.Core/Gradients/gradient_exclustions.cs index c6eab3b3..70d5a32f 100644 --- a/src/TensorFlowNET.Core/Gradients/gradient_exclustions.cs +++ b/src/TensorFlowNET.Core/Gradients/gradient_exclustions.cs @@ -20,6 +20,7 @@ namespace Tensorflow.Gradients public static int[] OpGradientUnusedOutputIndices(string op_name) => op_name switch { + "FusedBatchNormV3" => new[] { 0, 1, 2 }, "ReadVariableOp" => new int[0], "SoftmaxCrossEntropyWithLogits" => new[] { 0 }, "TensorArrayConcat" => new[] { 0 }, diff --git a/src/TensorFlowNET.Core/Keras/Engine/Flatten.cs b/src/TensorFlowNET.Core/Keras/Engine/Flatten.cs index d8197285..2f1aae0d 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Flatten.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Flatten.cs @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine _channels_first = args.DataFormat == "channels_first"; } - protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) { if (_channels_first) { diff --git a/src/TensorFlowNET.Core/Keras/Engine/Functional.cs b/src/TensorFlowNET.Core/Keras/Engine/Functional.cs index e990fb86..4180e37e 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Functional.cs @@ -268,7 +268,7 @@ namespace Tensorflow.Keras.Engine nodes_in_decreasing_depth.Insert(nodes_in_decreasing_depth.Count, node); } - protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) { return run_internal_graph(inputs, is_training); } @@ -305,6 +305,7 @@ namespace Tensorflow.Keras.Engine tensor_dict[node.FlatInputIds[0]] = new Tensor[0]; var outputs = node.Layer.Apply(layer_inputs, is_training: training); + // Update tensor_dict. foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs)) tensor_dict[x_id] = Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y).ToArray(); diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.Apply.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.Apply.cs index a2b6ef2d..06897cb8 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.Apply.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.Apply.cs @@ -46,7 +46,7 @@ namespace Tensorflow.Keras.Engine if (!built) MaybeBuild(inputs); - outputs = CallFn(inputs, state: state, is_training: is_training); + outputs = Call(inputs, state: state, is_training: is_training); outputs = _set_connectivity_metadata_(inputs, outputs); _handle_activity_regularization(inputs, outputs); diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.FunctionalConstructionCall.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.FunctionalConstructionCall.cs index a32952cb..a11c8850 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.FunctionalConstructionCall.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.FunctionalConstructionCall.cs @@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Engine if (!dynamic) throw new NotImplementedException(""); - outputs = CallFn(inputs); + outputs = Call(inputs); outputs = _set_connectivity_metadata_(inputs, outputs); _handle_activity_regularization(inputs, outputs); diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs index d64f0d1c..27537db6 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs @@ -162,7 +162,7 @@ namespace Tensorflow.Keras.Engine /// /// /// - protected virtual Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) + protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) { throw new NotImplementedException(""); } diff --git a/src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs b/src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs index 423287e0..c42dd9f2 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs @@ -23,9 +23,9 @@ namespace Tensorflow.Keras.Engine built = true; } - protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) { - return base.CallFn(inputs, state, is_training); + return base.Call(inputs, state, is_training); } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs index fc32a792..1ac0e649 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs @@ -119,7 +119,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) { Tensor outputs = null; diff --git a/src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs b/src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs index d8b4bad9..f4bc77c0 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs @@ -98,7 +98,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool training = false) { var outputs = _convolution_op.Apply(inputs, kernel); if (use_bias) diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs index 7eed3a63..e3e8b8e3 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs @@ -65,7 +65,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool training = false) { Tensor outputs = null; var rank = inputs.rank; diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dropout.cs b/src/TensorFlowNET.Core/Keras/Layers/Dropout.cs index ec4cebae..1b9f138f 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Dropout.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Dropout.cs @@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers this.args = args; } - protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) { var output = tf_utils.smart_cond(is_training, () => tf.nn.dropout(inputs, diff --git a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs index ef85d8a4..9962ff25 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs @@ -62,7 +62,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) { var dtype = inputs.dtype; if (dtype != tf.int32 && dtype != tf.int64) diff --git a/src/TensorFlowNET.Core/Keras/Layers/LSTM.cs b/src/TensorFlowNET.Core/Keras/Layers/LSTM.cs index 266081c0..87728fdf 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/LSTM.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/LSTM.cs @@ -29,9 +29,9 @@ namespace Tensorflow.Keras.Layers .ToArray(); } - protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) { - return base.CallFn(inputs, state: state, is_training: is_training); + return base.Call(inputs, state: state, is_training: is_training); } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs index a099caf2..daf57b1e 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs @@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers input_spec = new InputSpec(ndim: 4); } - protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) { int[] pool_shape; int[] strides; diff --git a/src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs b/src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs index b542bcbd..ba8f3901 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs @@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Layers this.args = args; } - protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) { scale = math_ops.cast(args.Scale, args.DType); offset = math_ops.cast(args.Offset, args.DType); diff --git a/src/TensorFlowNET.Core/Keras/Layers/ZeroPadding2D.cs b/src/TensorFlowNET.Core/Keras/Layers/ZeroPadding2D.cs index 2790857b..07dc6b03 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ZeroPadding2D.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ZeroPadding2D.cs @@ -29,7 +29,7 @@ namespace Tensorflow.Keras.Layers this.input_spec = new InputSpec(ndim: 4); } - protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) { return tf.keras.backend.spatial_2d_padding(inputs, padding: padding, diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs index 7a8b4311..ca4a7df2 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs @@ -74,7 +74,7 @@ namespace Tensorflow /// /// /// - protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) { var one = constant_op.constant(1, dtype: dtypes.int32); // Parameters of gates are concatenated into one multiply for efficiency. diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs index f1f49792..987f84c5 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs @@ -67,7 +67,7 @@ namespace Tensorflow built = true; } - protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) { // Most basic RNN: output = new_state = act(W * input + U * state + B). var concat = array_ops.concat(new Tensor[] { inputs, state }, 1); diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index fb19ab4e..ea82ceae 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -321,6 +321,23 @@ namespace Tensorflow.Operations bool is_training = true, string name = null) { + if (tf.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "FusedBatchNormV3", name, + null, + x, + scale, + offset, + mean, + variance, + "epsilon", epsilon, + "data_format", data_format, + "is_training", is_training); + + return results; + } + var _op = tf.OpDefLib._apply_op_helper("FusedBatchNormV3", name: name, args: new { x, diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index e1b9bd86..062d5037 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -79,7 +79,7 @@ https://tensorflownet.readthedocs.io - + diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs index f0905fb6..f8d6c1a3 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs @@ -158,6 +158,9 @@ namespace Tensorflow UnmanagedStorage storage; switch (dtype) { + case TF_DataType.TF_BOOL: + storage = new UnmanagedStorage(NPTypeCode.Boolean); + break; case TF_DataType.TF_STRING: return np.array(StringBytes()[0]); case TF_DataType.TF_INT32: