diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs index dc883a75..ba543953 100644 --- a/src/TensorFlowNET.Core/APIs/tf.layers.cs +++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs @@ -100,6 +100,32 @@ namespace Tensorflow return layer.apply(inputs, training: training); } + + /// + /// Max pooling layer for 2D inputs (e.g. images). + /// + /// The tensor over which to pool. Must have rank 4. + /// + /// + /// + /// + /// + /// + public static Tensor max_pooling2d(Tensor inputs, + int[] pool_size, + int[] strides, + string padding = "valid", + string data_format = "channels_last", + string name = null) + { + var layer = new MaxPooling2D(pool_size: pool_size, + strides: strides, + padding: padding, + data_format: data_format, + name: name); + + return layer.apply(inputs); + } } } } diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index 44203906..5a940afe 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Operations; using Tensorflow.Operations.Activation; namespace Tensorflow @@ -27,19 +28,21 @@ namespace Tensorflow public static IActivation relu => new relu(); - public static (Tensor, Tensor, Tensor) fused_batch_norm(Tensor x, - RefVariable scale, - RefVariable offset, - Tensor mean = null, - Tensor variance = null, - float epsilon = 0.001f, - string data_format = "NHWC", - bool is_training = true, - string name = null) => nn_impl.fused_batch_norm(x, scale, offset, mean, variance, - epsilon: epsilon, - data_format: data_format, - is_training: is_training, - name: name); + public static Tensor[] fused_batch_norm(Tensor x, + RefVariable scale, + RefVariable offset, + Tensor mean = null, + Tensor variance = null, + float epsilon = 0.001f, + string data_format = "NHWC", + bool is_training = true, + string name = null) => nn_impl.fused_batch_norm(x, scale, offset, mean, variance, + epsilon: epsilon, + data_format: data_format, + is_training: is_training, + name: name); + + public static Tensor max_pool() => gen_nn_ops.max_pool(); } } } diff --git a/src/TensorFlowNET.Core/Framework/smart_module.cs b/src/TensorFlowNET.Core/Framework/smart_module.cs index ea5bf790..57c3f67f 100644 --- a/src/TensorFlowNET.Core/Framework/smart_module.cs +++ b/src/TensorFlowNET.Core/Framework/smart_module.cs @@ -6,9 +6,9 @@ namespace Tensorflow.Framework { public class smart_module { - public static object smart_cond(Tensor pred, - Func<(Tensor, Tensor, Tensor)> true_fn = null, - Func<(Tensor, Tensor, Tensor)> false_fn = null, + public static Tensor[] smart_cond(Tensor pred, + Func true_fn = null, + Func false_fn = null, string name = null) { return control_flow_ops.cond(pred, @@ -17,9 +17,12 @@ namespace Tensorflow.Framework name: name); } - public static bool smart_constant_value(Tensor pred) + public static bool? smart_constant_value(Tensor pred) { var pred_value = tensor_util.constant_value(pred); + if (pred_value is null) + return null; + return pred_value; } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs index 2e442e65..7e722ff2 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; using Tensorflow.Keras.Utils; @@ -34,6 +35,7 @@ namespace Tensorflow.Keras.Engine protected string _name; protected string _base_name; protected bool _compute_previous_mask; + protected List _updates; public Layer(bool trainable = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid) { @@ -45,6 +47,7 @@ namespace Tensorflow.Keras.Engine _init_set_name(name); _trainable_weights = new List(); _compute_previous_mask = false; + _updates = new List(); } public Tensor __call__(Tensor inputs, @@ -142,6 +145,12 @@ namespace Tensorflow.Keras.Engine return variable; } + protected virtual void add_update(Tensor[] updates, bool inputs = false) + { + var updates_op = updates.Select(x => x.op).ToArray(); + _updates.AddRange(updates_op); + } + protected virtual void _init_set_name(string name) { string base_name = name; diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs index 1223e350..64f44386 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs @@ -132,6 +132,7 @@ namespace Tensorflow.Keras.Layers if (fused) { outputs = _fused_batch_norm(inputs, training: training); + return outputs; } throw new NotImplementedException("BatchNormalization call"); @@ -142,7 +143,7 @@ namespace Tensorflow.Keras.Layers var beta = this.beta; var gamma = this.gamma; - Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_training = () => + Func _fused_batch_norm_training = () => { return tf.nn.fused_batch_norm( inputs, @@ -152,7 +153,7 @@ namespace Tensorflow.Keras.Layers data_format: _data_format); }; - Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_inference = () => + Func _fused_batch_norm_inference = () => { return tf.nn.fused_batch_norm( inputs, @@ -165,9 +166,41 @@ namespace Tensorflow.Keras.Layers data_format: _data_format); }; - tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference); + var results = tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference); + var (output, mean, variance) = (results[0], results[1], results[2]); + var training_value = tf_utils.constant_value(training); - throw new NotImplementedException("_fused_batch_norm"); + Tensor momentum_tensor; + if (training_value == null) + { + momentum_tensor = tf_utils.smart_cond(training, + () => new float[] { momentum }, () => new float[] { 1.0f })[0]; + } + else + { + momentum_tensor = ops.convert_to_tensor(momentum); + } + + if(training_value == null) + { + var mean_update = _assign_moving_average(moving_mean, mean, momentum_tensor); + var variance_update = _assign_moving_average(moving_variance, variance, momentum_tensor); + add_update(new Tensor[] { mean_update }, inputs: true); + add_update(new Tensor[] { variance_update }, inputs: true); + } + + return output; + } + + public Tensor _assign_moving_average(RefVariable variable, Tensor value, Tensor momentum) + { + return Python.with(ops.name_scope(null, "AssignMovingAvg", new { variable, value, momentum }), scope => + { + // var cm = ops.colocate_with(variable); + var decay = ops.convert_to_tensor(1.0f - momentum, name: "decay"); + var update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay; + return state_ops.assign_sub(variable, update_delta, name: scope); + }); } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs new file mode 100644 index 00000000..649c1a33 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.tf; + +namespace Tensorflow.Keras.Layers +{ + public class MaxPooling2D : Pooling2D + { + public MaxPooling2D( + int[] pool_size, + int[] strides, + string padding = "valid", + string data_format = null, + string name = null) : base(nn.max_pool, pool_size, + strides, + padding: padding, + data_format: data_format, + name: name) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs new file mode 100644 index 00000000..1bdb769b --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Layers +{ + public class Pooling2D : Tensorflow.Layers.Layer + { + private Func pool_function; + private int[] pool_size; + private int[] strides; + private string padding; + private string data_format; + private InputSpec input_spec; + + public Pooling2D(Func pool_function, + int[] pool_size, + int[] strides, + string padding = "valid", + string data_format = null, + string name = null) : base(name: name) + { + this.pool_function = pool_function; + this.pool_size = conv_utils.normalize_tuple(pool_size, 2, "pool_size"); + this.strides = conv_utils.normalize_tuple(strides, 2, "strides"); + this.padding = conv_utils.normalize_padding(padding); + this.data_format = conv_utils.normalize_data_format(data_format); + this.input_spec = new InputSpec(ndim: 4); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Utils/conv_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/conv_utils.cs index ef348d1b..790470ee 100644 --- a/src/TensorFlowNET.Core/Keras/Utils/conv_utils.cs +++ b/src/TensorFlowNET.Core/Keras/Utils/conv_utils.cs @@ -29,5 +29,20 @@ namespace Tensorflow.Keras.Utils else throw new ValueError($"Invalid data_format: {data_format}"); } + + public static int[] normalize_tuple(int[] value, int n, string name) + { + return value; + } + + public static string normalize_padding(string value) + { + return value.ToLower(); + } + + public static string normalize_data_format(string value) + { + return value.ToLower(); + } } } diff --git a/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs index 4e155493..c57344c2 100644 --- a/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs +++ b/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs @@ -13,14 +13,19 @@ namespace Tensorflow.Keras.Utils return tensors.Select(x => is_symbolic_tensor(x)).Count() == tensors.Length; } + public static bool? constant_value(Tensor pred) + { + return smart_module.smart_constant_value(pred); + } + public static bool is_symbolic_tensor(Tensor tensor) { return true; } - public static object smart_cond(Tensor pred, - Func<(Tensor, Tensor, Tensor)> true_fn = null, - Func<(Tensor, Tensor, Tensor)> false_fn = null, + public static Tensor[] smart_cond(Tensor pred, + Func true_fn = null, + Func false_fn = null, string name = null) { return smart_module.smart_cond(pred, diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index 1ca856f0..17205c51 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; using Tensorflow.Keras.Engine; @@ -55,11 +56,23 @@ namespace Tensorflow.Layers var outputs = base.__call__(inputs, training: training); // Update global default collections. - //_add_elements_to_collection(updates, ops.GraphKeys.UPDATE_OPS); + _add_elements_to_collection(_updates.ToArray(), new string[] { ops.GraphKeys.UPDATE_OPS }); return outputs; } + protected virtual void _add_elements_to_collection(Operation[] elements, string[] collection_list) + { + foreach(var name in collection_list) + { + var collection = ops.get_collection_ref(name) as List; + + foreach (var element in elements) + if (!collection.Contains(element)) + collection.Add(element); + } + } + protected virtual RefVariable add_weight(string name, int[] shape, TF_DataType dtype = TF_DataType.DtInvalid, diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs index 3c233e8d..23799892 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs @@ -63,14 +63,23 @@ namespace Tensorflow.Operations } } - public (Tensor, Tensor, Tensor) BuildCondBranch(Func<(Tensor, Tensor, Tensor)> fn) + public (T[], Tensor[]) BuildCondBranch(Func fn) { // Add the subgraph defined by fn() to the graph. var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); var original_result = fn(); var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); - return original_result; + switch (original_result) + { + case Tensor[] results: + return (original_result, results); + case float[] fv: + var result = ops.convert_to_tensor(fv[0]); + return (original_result, new Tensor[] { result }); + default: + return (original_result, new Tensor[0]); + } } } } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index a93c1653..9dd853d9 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -53,7 +53,7 @@ namespace Tensorflow.Operations return _op.outputs[0]; } - public static (Tensor, Tensor, Tensor) _fused_batch_norm(Tensor x, + public static Tensor[] _fused_batch_norm(Tensor x, Tensor scale, Tensor offset, Tensor mean, @@ -75,7 +75,12 @@ namespace Tensorflow.Operations is_training }); - return (_op.outputs[0], _op.outputs[1], _op.outputs[2]); + return _op.outputs; + } + + public static Tensor max_pool() + { + throw new NotImplementedException(""); } } } diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index bca74989..01af43b7 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -137,9 +137,9 @@ namespace Tensorflow return gen_array_ops.identity(data, name: name); } - public static (Tensor, Tensor) cond(Tensor pred, - Func<(Tensor, Tensor, Tensor)> true_fn = null, - Func<(Tensor, Tensor, Tensor)> false_fn = null, + public static Tensor[] cond(Tensor pred, + Func true_fn = null, + Func false_fn = null, bool strict = false, string name = null) { @@ -158,20 +158,46 @@ namespace Tensorflow // Build the graph for the true branch in a new context. var context_t = new CondContext(pred, pivot_1, branch: 1); context_t.Enter(); - var res_t = context_t.BuildCondBranch(true_fn); + var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); context_t.Exit(); // Build the graph for the false branch in a new context. var context_f = new CondContext(pred, pivot_2, branch: 0); context_f.Enter(); - var res_f = context_f.BuildCondBranch(false_fn); + var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); context_f.Exit(); - var res_t_flat = new Tensor[] { res_t.Item1, res_t.Item2, res_t.Item3 }; - var res_f_flat = new Tensor[] { res_f.Item1, res_f.Item2, res_f.Item3 }; + var res_t_flat = res_t; + var res_f_flat = res_f; + var merges = zip(res_f_flat, res_t_flat) + .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })) + .ToArray(); - return (p_2, p_1); + merges = _convert_flows_to_tensorarrays(orig_res_t, merges); + + ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t); + ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f); + + return merges; + }); + } + + public static Tensor[] _convert_flows_to_tensorarrays(T[] tensors_or_tensorarrays, Tensor[] tensors_or_flows) + { + // zip(tensors_or_tensorarrays, tensors_or_flows).Select((ta, t_or_flow) => ta).ToArray(); + return tensors_or_flows; + } + + public static Tensor merge(Tensor[] inputs, string name = null) + { + return with(ops.name_scope(name, "Merge", inputs), scope => + { + name = scope; + inputs = inputs.Select(inp => + ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true)) + .ToArray(); + return gen_control_flow_ops.merge(inputs, name).Item1; }); } diff --git a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs index faedfae4..21447c57 100644 --- a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs @@ -4,7 +4,7 @@ using System.Text; namespace Tensorflow { - public class gen_control_flow_ops + public class gen_control_flow_ops : Python { public static OpDefLibrary _op_def_lib = new OpDefLibrary(); @@ -21,5 +21,12 @@ namespace Tensorflow return (_op.outputs[0], _op.outputs[1]); } + + public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Merge", name, new { inputs }); + + return (_op.outputs[0], _op.outputs[1]); + } } } diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs index 81515e18..c5de18da 100644 --- a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs +++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs @@ -46,7 +46,7 @@ namespace Tensorflow }); } - public static (Tensor, Tensor, Tensor) fused_batch_norm(Tensor x, + public static Tensor[] fused_batch_norm(Tensor x, RefVariable scale, RefVariable offset, Tensor mean, diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs index 1a7beb72..b0382553 100644 --- a/src/TensorFlowNET.Core/Python.cs +++ b/src/TensorFlowNET.Core/Python.cs @@ -118,14 +118,6 @@ namespace Tensorflow { object obj = propertyDescriptor.GetValue(dyn); string name = propertyDescriptor.Name; - // avoid .net keyword - switch (name) - { - case "_ref_": - name = "ref"; - break; - } - dictionary.Add(name, obj); } return dictionary; diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs index cd3ac548..8d941d21 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs @@ -14,6 +14,7 @@ namespace Tensorflow public static Tensor operator -(Tensor x, Tensor y) => BinaryOpWrapper("sub", x, y); public static Tensor operator -(Tensor x, int y) => BinaryOpWrapper("sub", x, y); public static Tensor operator -(Tensor x, double y) => BinaryOpWrapper("sub", x, y); + public static Tensor operator -(float x, Tensor y) => BinaryOpWrapper("Sub", x, y); public static Tensor operator *(float x, Tensor y) => BinaryOpWrapper("mul", x, y); public static Tensor operator *(double x, Tensor y) => BinaryOpWrapper("mul", x, y); @@ -48,7 +49,7 @@ namespace Tensorflow var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y"); - switch (name) + switch (name.ToLower()) { case "add": result = gen_math_ops.add(x1, y1, name: scope); diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index bee8e68e..567f5c9f 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -38,7 +38,8 @@ namespace Tensorflow { return MakeNdarray(tensor.op.get_attr("value") as TensorProto); } - throw new NotImplementedException("_ConstantValue"); + + return null; } public static NDArray MakeNdarray(TensorProto tensor) diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs b/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs index c0b23575..0d27227a 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs @@ -13,7 +13,8 @@ namespace Tensorflow public static Tensor operator -(RefVariable x, int y) => op_helper("sub", x, y); public static Tensor operator -(RefVariable x, float y) => op_helper("sub", x, y); public static Tensor operator -(RefVariable x, double y) => op_helper("sub", x, y); - + public static Tensor operator -(RefVariable x, Tensor y) => op_helper("sub", x, y); + private static Tensor op_helper(string default_name, RefVariable x, T y) { var tensor1 = x.value(); diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs index 737d95b1..8e75c857 100644 --- a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs +++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs @@ -52,7 +52,7 @@ namespace Tensorflow bool use_locking = true, string name = null) { - var _op = _op_def_lib._apply_op_helper("Assign", name: name, args: new { _ref_ = tensor, value, validate_shape, use_locking }); + var _op = _op_def_lib._apply_op_helper("Assign", name: name, args: new { @ref = tensor, value, validate_shape, use_locking }); var _result = _op.outputs; var _inputs_flat = _op.inputs; @@ -66,5 +66,15 @@ namespace Tensorflow return _result[0]; } + + public static Tensor assign_sub(RefVariable @ref, + Tensor value, + bool use_locking = false, + string name = null) + { + var _op = _op_def_lib._apply_op_helper("AssignSub", name: name, args: new { @ref, value, use_locking }); + + return _op.outputs[0]; + } } } diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs index 0144a138..b2cb6082 100644 --- a/src/TensorFlowNET.Core/Variables/state_ops.cs +++ b/src/TensorFlowNET.Core/Variables/state_ops.cs @@ -24,5 +24,13 @@ namespace Tensorflow name: name, container: container, shared_name: shared_name); + + public static Tensor assign_sub(RefVariable @ref, + Tensor value, + bool use_locking = false, + string name = null) => gen_state_ops.assign_sub(@ref, + value, + use_locking: use_locking, + name: name); } } diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs index 61d527f6..9e1e72f2 100644 --- a/src/TensorFlowNET.Core/ops.GraphKeys.cs +++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs @@ -47,6 +47,10 @@ namespace Tensorflow // Used to store v2 summary names. public static string _SUMMARY_COLLECTION = "_SUMMARY_V2"; + + // Key for control flow context. + public static string COND_CONTEXT = "cond_context"; + public static string WHILE_CONTEXT = "while_context"; } } } diff --git a/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs b/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs index 05929d3d..d62e52c1 100644 --- a/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs +++ b/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs @@ -93,7 +93,11 @@ namespace TensorFlowNET.Examples.TextClassification if (max_pool) { // Max pooling - throw new NotImplementedException("conv_block"); + return tf.layers.max_pooling2d( + conv, + pool_size: new int[] { 3, 1 }, + strides: new int[] { 2, 1 }, + padding: "SAME"); } else {