diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs
index dc883a75..ba543953 100644
--- a/src/TensorFlowNET.Core/APIs/tf.layers.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs
@@ -100,6 +100,32 @@ namespace Tensorflow
return layer.apply(inputs, training: training);
}
+
+ ///
+ /// Max pooling layer for 2D inputs (e.g. images).
+ ///
+ /// The tensor over which to pool. Must have rank 4.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static Tensor max_pooling2d(Tensor inputs,
+ int[] pool_size,
+ int[] strides,
+ string padding = "valid",
+ string data_format = "channels_last",
+ string name = null)
+ {
+ var layer = new MaxPooling2D(pool_size: pool_size,
+ strides: strides,
+ padding: padding,
+ data_format: data_format,
+ name: name);
+
+ return layer.apply(inputs);
+ }
}
}
}
diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index 44203906..5a940afe 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
+using Tensorflow.Operations;
using Tensorflow.Operations.Activation;
namespace Tensorflow
@@ -27,19 +28,21 @@ namespace Tensorflow
public static IActivation relu => new relu();
- public static (Tensor, Tensor, Tensor) fused_batch_norm(Tensor x,
- RefVariable scale,
- RefVariable offset,
- Tensor mean = null,
- Tensor variance = null,
- float epsilon = 0.001f,
- string data_format = "NHWC",
- bool is_training = true,
- string name = null) => nn_impl.fused_batch_norm(x, scale, offset, mean, variance,
- epsilon: epsilon,
- data_format: data_format,
- is_training: is_training,
- name: name);
+ public static Tensor[] fused_batch_norm(Tensor x,
+ RefVariable scale,
+ RefVariable offset,
+ Tensor mean = null,
+ Tensor variance = null,
+ float epsilon = 0.001f,
+ string data_format = "NHWC",
+ bool is_training = true,
+ string name = null) => nn_impl.fused_batch_norm(x, scale, offset, mean, variance,
+ epsilon: epsilon,
+ data_format: data_format,
+ is_training: is_training,
+ name: name);
+
+ public static Tensor max_pool() => gen_nn_ops.max_pool();
}
}
}
diff --git a/src/TensorFlowNET.Core/Framework/smart_module.cs b/src/TensorFlowNET.Core/Framework/smart_module.cs
index ea5bf790..57c3f67f 100644
--- a/src/TensorFlowNET.Core/Framework/smart_module.cs
+++ b/src/TensorFlowNET.Core/Framework/smart_module.cs
@@ -6,9 +6,9 @@ namespace Tensorflow.Framework
{
public class smart_module
{
- public static object smart_cond(Tensor pred,
- Func<(Tensor, Tensor, Tensor)> true_fn = null,
- Func<(Tensor, Tensor, Tensor)> false_fn = null,
+ public static Tensor[] smart_cond(Tensor pred,
+ Func true_fn = null,
+ Func false_fn = null,
string name = null)
{
return control_flow_ops.cond(pred,
@@ -17,9 +17,12 @@ namespace Tensorflow.Framework
name: name);
}
- public static bool smart_constant_value(Tensor pred)
+ public static bool? smart_constant_value(Tensor pred)
{
var pred_value = tensor_util.constant_value(pred);
+ if (pred_value is null)
+ return null;
+
return pred_value;
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs
index 2e442e65..7e722ff2 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Text;
using Tensorflow.Keras.Utils;
@@ -34,6 +35,7 @@ namespace Tensorflow.Keras.Engine
protected string _name;
protected string _base_name;
protected bool _compute_previous_mask;
+ protected List _updates;
public Layer(bool trainable = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid)
{
@@ -45,6 +47,7 @@ namespace Tensorflow.Keras.Engine
_init_set_name(name);
_trainable_weights = new List();
_compute_previous_mask = false;
+ _updates = new List();
}
public Tensor __call__(Tensor inputs,
@@ -142,6 +145,12 @@ namespace Tensorflow.Keras.Engine
return variable;
}
+ protected virtual void add_update(Tensor[] updates, bool inputs = false)
+ {
+ var updates_op = updates.Select(x => x.op).ToArray();
+ _updates.AddRange(updates_op);
+ }
+
protected virtual void _init_set_name(string name)
{
string base_name = name;
diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
index 1223e350..64f44386 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
@@ -132,6 +132,7 @@ namespace Tensorflow.Keras.Layers
if (fused)
{
outputs = _fused_batch_norm(inputs, training: training);
+ return outputs;
}
throw new NotImplementedException("BatchNormalization call");
@@ -142,7 +143,7 @@ namespace Tensorflow.Keras.Layers
var beta = this.beta;
var gamma = this.gamma;
- Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_training = () =>
+ Func _fused_batch_norm_training = () =>
{
return tf.nn.fused_batch_norm(
inputs,
@@ -152,7 +153,7 @@ namespace Tensorflow.Keras.Layers
data_format: _data_format);
};
- Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_inference = () =>
+ Func _fused_batch_norm_inference = () =>
{
return tf.nn.fused_batch_norm(
inputs,
@@ -165,9 +166,41 @@ namespace Tensorflow.Keras.Layers
data_format: _data_format);
};
- tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference);
+ var results = tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference);
+ var (output, mean, variance) = (results[0], results[1], results[2]);
+ var training_value = tf_utils.constant_value(training);
- throw new NotImplementedException("_fused_batch_norm");
+ Tensor momentum_tensor;
+ if (training_value == null)
+ {
+ momentum_tensor = tf_utils.smart_cond(training,
+ () => new float[] { momentum }, () => new float[] { 1.0f })[0];
+ }
+ else
+ {
+ momentum_tensor = ops.convert_to_tensor(momentum);
+ }
+
+ if(training_value == null)
+ {
+ var mean_update = _assign_moving_average(moving_mean, mean, momentum_tensor);
+ var variance_update = _assign_moving_average(moving_variance, variance, momentum_tensor);
+ add_update(new Tensor[] { mean_update }, inputs: true);
+ add_update(new Tensor[] { variance_update }, inputs: true);
+ }
+
+ return output;
+ }
+
+ public Tensor _assign_moving_average(RefVariable variable, Tensor value, Tensor momentum)
+ {
+ return Python.with(ops.name_scope(null, "AssignMovingAvg", new { variable, value, momentum }), scope =>
+ {
+ // var cm = ops.colocate_with(variable);
+ var decay = ops.convert_to_tensor(1.0f - momentum, name: "decay");
+ var update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay;
+ return state_ops.assign_sub(variable, update_delta, name: scope);
+ });
}
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs
new file mode 100644
index 00000000..649c1a33
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs
@@ -0,0 +1,24 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.tf;
+
+namespace Tensorflow.Keras.Layers
+{
+ public class MaxPooling2D : Pooling2D
+ {
+ public MaxPooling2D(
+ int[] pool_size,
+ int[] strides,
+ string padding = "valid",
+ string data_format = null,
+ string name = null) : base(nn.max_pool, pool_size,
+ strides,
+ padding: padding,
+ data_format: data_format,
+ name: name)
+ {
+
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs
new file mode 100644
index 00000000..1bdb769b
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs
@@ -0,0 +1,33 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Utils;
+
+namespace Tensorflow.Keras.Layers
+{
+ public class Pooling2D : Tensorflow.Layers.Layer
+ {
+ private Func pool_function;
+ private int[] pool_size;
+ private int[] strides;
+ private string padding;
+ private string data_format;
+ private InputSpec input_spec;
+
+ public Pooling2D(Func pool_function,
+ int[] pool_size,
+ int[] strides,
+ string padding = "valid",
+ string data_format = null,
+ string name = null) : base(name: name)
+ {
+ this.pool_function = pool_function;
+ this.pool_size = conv_utils.normalize_tuple(pool_size, 2, "pool_size");
+ this.strides = conv_utils.normalize_tuple(strides, 2, "strides");
+ this.padding = conv_utils.normalize_padding(padding);
+ this.data_format = conv_utils.normalize_data_format(data_format);
+ this.input_spec = new InputSpec(ndim: 4);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Utils/conv_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/conv_utils.cs
index ef348d1b..790470ee 100644
--- a/src/TensorFlowNET.Core/Keras/Utils/conv_utils.cs
+++ b/src/TensorFlowNET.Core/Keras/Utils/conv_utils.cs
@@ -29,5 +29,20 @@ namespace Tensorflow.Keras.Utils
else
throw new ValueError($"Invalid data_format: {data_format}");
}
+
+ public static int[] normalize_tuple(int[] value, int n, string name)
+ {
+ return value;
+ }
+
+ public static string normalize_padding(string value)
+ {
+ return value.ToLower();
+ }
+
+ public static string normalize_data_format(string value)
+ {
+ return value.ToLower();
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs
index 4e155493..c57344c2 100644
--- a/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs
+++ b/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs
@@ -13,14 +13,19 @@ namespace Tensorflow.Keras.Utils
return tensors.Select(x => is_symbolic_tensor(x)).Count() == tensors.Length;
}
+ public static bool? constant_value(Tensor pred)
+ {
+ return smart_module.smart_constant_value(pred);
+ }
+
public static bool is_symbolic_tensor(Tensor tensor)
{
return true;
}
- public static object smart_cond(Tensor pred,
- Func<(Tensor, Tensor, Tensor)> true_fn = null,
- Func<(Tensor, Tensor, Tensor)> false_fn = null,
+ public static Tensor[] smart_cond(Tensor pred,
+ Func true_fn = null,
+ Func false_fn = null,
string name = null)
{
return smart_module.smart_cond(pred,
diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs
index 1ca856f0..17205c51 100644
--- a/src/TensorFlowNET.Core/Layers/Layer.cs
+++ b/src/TensorFlowNET.Core/Layers/Layer.cs
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Text;
using Tensorflow.Keras.Engine;
@@ -55,11 +56,23 @@ namespace Tensorflow.Layers
var outputs = base.__call__(inputs, training: training);
// Update global default collections.
- //_add_elements_to_collection(updates, ops.GraphKeys.UPDATE_OPS);
+ _add_elements_to_collection(_updates.ToArray(), new string[] { ops.GraphKeys.UPDATE_OPS });
return outputs;
}
+ protected virtual void _add_elements_to_collection(Operation[] elements, string[] collection_list)
+ {
+ foreach(var name in collection_list)
+ {
+ var collection = ops.get_collection_ref(name) as List