diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs
index c63fe7fe..b10e1d77 100644
--- a/src/TensorFlowNET.Core/APIs/tf.array.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.array.cs
@@ -30,5 +30,13 @@ namespace Tensorflow
///
public static Tensor transpose(Tensor a, int[] perm = null, string name = "transpose", bool conjugate = false)
=> array_ops.transpose(a, perm, name, conjugate);
+
+        /// <summary>
+        /// Thin wrapper that forwards to <c>gen_array_ops.squeeze</c>.
+        /// </summary>
+        /// <param name="input">Tensor handed to the generated op.</param>
+        /// <param name="axis">Forwarded unchanged to the generated op; may be null.</param>
+        /// <param name="name">Optional operation name, forwarded unchanged.</param>
+        /// <param name="squeeze_dims">Accepted by the signature but not forwarded by this wrapper.</param>
+        /// <returns>The tensor returned by <c>gen_array_ops.squeeze</c>.</returns>
+        public static Tensor squeeze(Tensor input, int[] axis = null, string name = null, int squeeze_dims = -1)
+        {
+            return gen_array_ops.squeeze(input, axis, name);
+        }
+
+ /// <summary>
+ /// Placeholder for the one-hot op; the body is not implemented yet.
+ /// </summary>
+ /// <param name="indices">Currently unused.</param>
+ /// <param name="depth">Currently unused.</param>
+ /// <exception cref="NotImplementedException">Always thrown.</exception>
+ public static Tensor one_hot(Tensor indices, int depth)
+ {
+ throw new NotImplementedException("one_hot");
+ }
}
}
diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs
index fb5421c3..c6455341 100644
--- a/src/TensorFlowNET.Core/APIs/tf.init.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.init.cs
@@ -27,5 +27,13 @@ namespace Tensorflow
default_name,
values,
auxiliary_name_scope);
+
+        /// <summary>
+        /// Creates a <c>TruncatedNormal</c> initializer from the given parameters.
+        /// </summary>
+        /// <param name="mean">Forwarded to the initializer as <c>mean</c>.</param>
+        /// <param name="stddev">Forwarded to the initializer as <c>stddev</c>.</param>
+        /// <param name="seed">Optional seed, forwarded unchanged.</param>
+        /// <param name="dtype">
+        /// Element type. <c>DtInvalid</c> means "not specified" and is resolved to
+        /// <c>TF_FLOAT</c>, the default of the <c>TruncatedNormal</c> constructor.
+        /// </param>
+        /// <returns>A new <c>TruncatedNormal</c> instance.</returns>
+        public static IInitializer truncated_normal_initializer(float mean = 0.0f,
+            float stddev = 1.0f,
+            int? seed = null,
+            TF_DataType dtype = TF_DataType.DtInvalid)
+        {
+            // Passing DtInvalid explicitly would override TruncatedNormal's own TF_FLOAT
+            // default, so an unspecified dtype is mapped back to that default here.
+            var resolved_dtype = dtype == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : dtype;
+
+            return new TruncatedNormal(mean: mean,
+                stddev: stddev,
+                seed: seed,
+                dtype: resolved_dtype);
+        }
}
}
diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs
index ba543953..faf0d089 100644
--- a/src/TensorFlowNET.Core/APIs/tf.layers.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs
@@ -126,6 +126,26 @@ namespace Tensorflow
return layer.apply(inputs);
}
+
+            /// <summary>
+            /// Creates a <c>Dense</c> layer from the given options and applies it to
+            /// <paramref name="inputs"/>.
+            /// </summary>
+            /// <param name="inputs">Tensor passed to the layer's <c>apply</c>.</param>
+            /// <param name="units">Forwarded to the <c>Dense</c> constructor.</param>
+            /// <param name="activation">Optional activation, forwarded to the layer.</param>
+            /// <param name="use_bias">Forwarded to the layer.</param>
+            /// <param name="kernel_initializer">Optional kernel initializer, forwarded to the layer.</param>
+            /// <param name="bias_initializer">Defaults to <c>tf.zeros_initializer</c>; see the note in the body.</param>
+            /// <param name="trainable">Forwarded to the layer.</param>
+            /// <param name="name">Not used by this implementation.</param>
+            /// <param name="reuse">Not used by this implementation.</param>
+            /// <returns>The result of <c>layer.apply(inputs)</c>.</returns>
+            public static Tensor dense(Tensor inputs,
+                int units,
+                IActivation activation = null,
+                bool use_bias = true,
+                IInitializer kernel_initializer = null,
+                IInitializer bias_initializer = null,
+                bool trainable = true,
+                string name = null,
+                bool? reuse = null)
+            {
+                // NOTE(review): the constructor call below passes no bias initializer, so this
+                // defaulted value is currently unused. TODO: forward it once the Dense
+                // constructor resolved here exposes a bias_initializer parameter.
+                if (bias_initializer == null)
+                    bias_initializer = tf.zeros_initializer;
+
+                // trainable must be forwarded explicitly: this method defaults it to true,
+                // while the Dense constructors default it to false.
+                var layer = new Dense(units, activation,
+                    use_bias: use_bias,
+                    trainable: trainable,
+                    kernel_initializer: kernel_initializer);
+
+                return layer.apply(inputs);
+            }
}
}
}
diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index 9226ce63..ad548864 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -6,21 +6,26 @@ namespace Tensorflow
{
public static partial class tf
{
- public static Tensor add(Tensor a, Tensor b) => gen_math_ops.add(a, b);
+        /// <summary>
+        /// Forwards both operands to <c>gen_math_ops.add</c>.
+        /// </summary>
+        /// <param name="a">First operand.</param>
+        /// <param name="b">Second operand.</param>
+        /// <returns>The tensor returned by <c>gen_math_ops.add</c>.</returns>
+        public static Tensor add(Tensor a, Tensor b)
+        {
+            return gen_math_ops.add(a, b);
+        }
- public static Tensor sub(Tensor a, Tensor b) => gen_math_ops.sub(a, b);
+        /// <summary>
+        /// Forwards both operands to <c>gen_math_ops.sub</c>.
+        /// </summary>
+        /// <param name="a">First operand.</param>
+        /// <param name="b">Second operand.</param>
+        /// <returns>The tensor returned by <c>gen_math_ops.sub</c>.</returns>
+        public static Tensor sub(Tensor a, Tensor b)
+        {
+            return gen_math_ops.sub(a, b);
+        }
- public static Tensor sqrt(Tensor a, string name = null) => gen_math_ops.sqrt(a, name);
+        /// <summary>
+        /// Forwards the operand and optional name to <c>gen_math_ops.sqrt</c>.
+        /// </summary>
+        /// <param name="a">Operand.</param>
+        /// <param name="name">Optional operation name, forwarded unchanged.</param>
+        /// <returns>The tensor returned by <c>gen_math_ops.sqrt</c>.</returns>
+        public static Tensor sqrt(Tensor a, string name = null)
+        {
+            return gen_math_ops.sqrt(a, name);
+        }
public static Tensor subtract(Tensor x, T[] y, string name = null) where T : struct
=> gen_math_ops.sub(x, ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"), name);
- public static Tensor multiply(Tensor x, Tensor y) => gen_math_ops.mul(x, y);
+        /// <summary>
+        /// Forwards both operands to <c>gen_math_ops.mul</c>.
+        /// </summary>
+        /// <param name="x">First operand.</param>
+        /// <param name="y">Second operand.</param>
+        /// <returns>The tensor returned by <c>gen_math_ops.mul</c>.</returns>
+        public static Tensor multiply(Tensor x, Tensor y)
+        {
+            return gen_math_ops.mul(x, y);
+        }
public static Tensor divide(Tensor x, T[] y, string name = null) where T : struct
=> x / ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y");
- public static Tensor pow(T1 x, T2 y) => gen_math_ops.pow(x, y);
+ /// <summary>
+ /// Forwards both operands to gen_math_ops.pow and returns its result.
+ /// </summary>
+ // NOTE(review): T1 and T2 are not declared in the visible signature; presumably the
+ // generic parameter list was lost when this patch was extracted - confirm against the file.
+ public static Tensor pow(T1 x, T2 y)
+ => gen_math_ops.pow(x, y);
///
/// Computes the sum of elements across dimensions of a tensor.
@@ -28,9 +33,13 @@ namespace Tensorflow
///
///
///
- public static Tensor reduce_sum(Tensor input, int[] axis = null) => math_ops.reduce_sum(input);
+        public static Tensor reduce_sum(Tensor input, int[] axis = null)
+        {
+            // Only input reaches math_ops.reduce_sum; axis is accepted by the signature
+            // but is not forwarded.
+            return math_ops.reduce_sum(input);
+        }
public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
=> math_ops.cast(x, dtype, name);
+
+        /// <summary>
+        /// Forwards to <c>gen_math_ops.arg_max</c> to obtain the index of the largest value
+        /// along one axis.
+        /// </summary>
+        /// <param name="input">Tensor handed to the generated op.</param>
+        /// <param name="axis">Axis to reduce over; used when <paramref name="dimension"/> is null.</param>
+        /// <param name="name">Optional operation name, forwarded unchanged.</param>
+        /// <param name="dimension">
+        /// Legacy spelling of <paramref name="axis"/>. When supplied it takes precedence,
+        /// so callers using the old keyword are no longer silently ignored.
+        /// </param>
+        /// <param name="output_type">Index type requested from the op.</param>
+        /// <returns>The tensor returned by <c>gen_math_ops.arg_max</c>.</returns>
+        public static Tensor argmax(Tensor input, int axis = -1, string name = null, int? dimension = null, TF_DataType output_type = TF_DataType.TF_INT64)
+        {
+            // dimension is the deprecated alias of axis in the TensorFlow API; honour it when given.
+            var reduction_axis = dimension ?? axis;
+
+            return gen_math_ops.arg_max(input, reduction_axis, name: name, output_type: output_type);
+        }
}
}
diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index 5a940afe..87b646d7 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -42,7 +42,10 @@ namespace Tensorflow
is_training: is_training,
name: name);
- public static Tensor max_pool() => gen_nn_ops.max_pool();
+            /// <summary>
+            /// Gets a max-pooling function object. A new <c>MaxPoolFunction</c> is created on
+            /// every access.
+            /// </summary>
+            public static IPoolFunction max_pool
+            {
+                get { return new MaxPoolFunction(); }
+            }
+
+            /// <summary>
+            /// Forwards to <c>gen_nn_ops.top_kv2</c>.
+            /// </summary>
+            /// <param name="input">Tensor handed to the generated op.</param>
+            /// <param name="k">Forwarded as <c>k</c>.</param>
+            /// <param name="sorted">Forwarded as <c>sorted</c>.</param>
+            /// <param name="name">Optional operation name, forwarded unchanged.</param>
+            /// <returns>The output tensors returned by <c>gen_nn_ops.top_kv2</c>.</returns>
+            public static Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null)
+            {
+                return gen_nn_ops.top_kv2(input, k: k, sorted: sorted, name: name);
+            }
}
}
}
diff --git a/src/TensorFlowNET.Core/APIs/tf.reshape.cs b/src/TensorFlowNET.Core/APIs/tf.reshape.cs
index 41861968..30098581 100644
--- a/src/TensorFlowNET.Core/APIs/tf.reshape.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.reshape.cs
@@ -10,5 +10,8 @@ namespace Tensorflow
Tensor shape,
string name = null) => gen_array_ops.reshape(tensor, shape, name);
+        /// <summary>
+        /// Overload taking the target shape as an integer array; forwards to
+        /// <c>gen_array_ops.reshape</c>.
+        /// </summary>
+        /// <param name="tensor">Tensor handed to the generated op.</param>
+        /// <param name="shape">Target shape, forwarded unchanged.</param>
+        /// <param name="name">Optional operation name, forwarded unchanged.</param>
+        /// <returns>The tensor returned by <c>gen_array_ops.reshape</c>.</returns>
+        public static Tensor reshape(Tensor tensor, int[] shape, string name = null)
+        {
+            return gen_array_ops.reshape(tensor, shape, name);
+        }
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs
index bfd96d6a..d84c31e1 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs
@@ -10,16 +10,19 @@ namespace Tensorflow.Keras.Engine
public class InputSpec
{
public int ndim;
+ public int? min_ndim;
Dictionary axes;
- public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid,
+ public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid,
int? ndim = null,
+ int? min_ndim = null,
Dictionary axes = null)
{
this.ndim = ndim.Value;
if (axes == null)
axes = new Dictionary();
this.axes = axes;
+ this.min_ndim = min_ndim;
}
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs
index 7e722ff2..12df6b4d 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs
@@ -122,7 +122,7 @@ namespace Tensorflow.Keras.Engine
+ /// <summary>
+ /// Default build step: only marks the layer as built. Declared virtual so derived
+ /// layers can override it.
+ /// </summary>
+ /// <param name="input_shape">Not used by this base implementation.</param>
protected virtual void build(TensorShape input_shape)
{
- throw new NotImplementedException("Layer.build");
+ built = true;
}
protected virtual RefVariable add_weight(string name,
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs
new file mode 100644
index 00000000..7bf63413
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs
@@ -0,0 +1,33 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.Engine;
+using Tensorflow.Operations.Activation;
+
+namespace Tensorflow.Keras.Layers
+{
+    /// <summary>
+    /// Dense layer definition. The constructor records the layer configuration, enables
+    /// masking support and declares an input spec with a minimum of two dimensions.
+    /// </summary>
+    public class Dense : Tensorflow.Layers.Layer
+    {
+        // NOTE(review): "uints" looks like a typo for "units"; the name is kept because the
+        // field is protected and may be referenced by derived classes.
+        protected int uints;
+        protected IActivation activation;
+        protected bool use_bias;
+        protected IInitializer kernel_initializer;
+        protected IInitializer bias_initializer;
+
+        /// <summary>
+        /// Stores the configuration of the layer.
+        /// </summary>
+        /// <param name="units">Stored in <c>uints</c>.</param>
+        /// <param name="activation">Stored as given; may be null.</param>
+        /// <param name="use_bias">Stored as given.</param>
+        /// <param name="trainable">Passed to the base constructor.</param>
+        /// <param name="kernel_initializer">Stored as given; may be null.</param>
+        /// <param name="bias_initializer">Stored as given; may be null.</param>
+        public Dense(int units,
+            IActivation activation,
+            bool use_bias = true,
+            bool trainable = false,
+            IInitializer kernel_initializer = null,
+            IInitializer bias_initializer = null)
+            : base(trainable: trainable)
+        {
+            // Configuration supplied by the caller.
+            uints = units;
+            this.activation = activation;
+            this.use_bias = use_bias;
+            this.kernel_initializer = kernel_initializer;
+            this.bias_initializer = bias_initializer;
+
+            // Fixed layer traits.
+            supports_masking = true;
+            var spec = new InputSpec(min_ndim: 2);
+            input_spec = spec;
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/IPoolFunction.cs b/src/TensorFlowNET.Core/Keras/Layers/IPoolFunction.cs
new file mode 100644
index 00000000..07544f10
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Layers/IPoolFunction.cs
@@ -0,0 +1,16 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ /// <summary>
+ /// Contract for a pooling operation that can be passed to a pooling layer.
+ /// </summary>
+ public interface IPoolFunction
+ {
+ /// <summary>
+ /// Applies the pooling operation to value.
+ /// </summary>
+ /// <param name="value">Input tensor.</param>
+ /// <param name="ksize">Pooling window sizes.</param>
+ /// <param name="strides">Pooling strides.</param>
+ /// <param name="padding">Padding mode string.</param>
+ /// <param name="data_format">Data layout string; defaults to "NHWC".</param>
+ /// <param name="name">Optional operation name.</param>
+ /// <returns>The pooled tensor.</returns>
+ Tensor Apply(Tensor value,
+ int[] ksize,
+ int[] strides,
+ string padding,
+ string data_format = "NHWC",
+ string name = null);
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs
index 1bdb769b..69c4d65c 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs
@@ -8,14 +8,14 @@ namespace Tensorflow.Keras.Layers
{
public class Pooling2D : Tensorflow.Layers.Layer
{
- private Func pool_function;
+ private IPoolFunction pool_function;
private int[] pool_size;
private int[] strides;
private string padding;
private string data_format;
private InputSpec input_spec;
- public Pooling2D(Func pool_function,
+ public Pooling2D(IPoolFunction pool_function,
int[] pool_size,
int[] strides,
string padding = "valid",
@@ -29,5 +29,29 @@ namespace Tensorflow.Keras.Layers
this.data_format = conv_utils.normalize_data_format(data_format);
this.input_spec = new InputSpec(ndim: 4);
}
+
+        /// <summary>
+        /// Expands the 2-element pool size and strides to 4-element arrays matching the
+        /// configured data format and applies the pooling function to
+        /// <paramref name="inputs"/>.
+        /// </summary>
+        /// <param name="inputs">Tensor passed to the pooling function.</param>
+        /// <param name="training">Not used by this implementation.</param>
+        /// <returns>The tensor returned by the pooling function.</returns>
+        protected override Tensor call(Tensor inputs, Tensor training = null)
+        {
+            // The expanded arrays are locals on purpose: writing the 4-element result back
+            // into the strides field would change what strides[0]/strides[1] mean on the
+            // next call and produce wrong strides.
+            int[] pool_shape;
+            int[] pool_strides;
+            if (data_format == "channels_last")
+            {
+                pool_shape = new int[] { 1, pool_size[0], pool_size[1], 1 };
+                pool_strides = new int[] { 1, strides[0], strides[1], 1 };
+            }
+            else
+            {
+                pool_shape = new int[] { 1, 1, pool_size[0], pool_size[1] };
+                pool_strides = new int[] { 1, 1, strides[0], strides[1] };
+            }
+
+            // ToUpperInvariant keeps the padding keyword independent of the current culture
+            // (culture-sensitive upper-casing can alter letters such as 'i').
+            return pool_function.Apply(
+                inputs,
+                ksize: pool_shape,
+                strides: pool_strides,
+                padding: padding.ToUpperInvariant(),
+                data_format: conv_utils.convert_data_format(data_format, 4));
+        }
}
}
diff --git a/src/TensorFlowNET.Core/Layers/Dense.cs b/src/TensorFlowNET.Core/Layers/Dense.cs
new file mode 100644
index 00000000..e2868877
--- /dev/null
+++ b/src/TensorFlowNET.Core/Layers/Dense.cs
@@ -0,0 +1,23 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Operations.Activation;
+
+namespace Tensorflow.Layers
+{
+    /// <summary>
+    /// Dense layer in the <c>Tensorflow.Layers</c> namespace; forwards all constructor
+    /// arguments to <c>Keras.Layers.Dense</c>.
+    /// </summary>
+    public class Dense : Keras.Layers.Dense
+    {
+        /// <summary>
+        /// Creates the layer by passing every argument to the base constructor.
+        /// </summary>
+        /// <param name="units">Forwarded to the base constructor.</param>
+        /// <param name="activation">Forwarded to the base constructor.</param>
+        /// <param name="use_bias">Forwarded to the base constructor.</param>
+        /// <param name="trainable">Forwarded to the base constructor.</param>
+        /// <param name="kernel_initializer">Forwarded to the base constructor.</param>
+        /// <param name="bias_initializer">
+        /// Optional bias initializer. Added as a trailing optional parameter so existing
+        /// callers are unaffected; previously the base option could not be reached.
+        /// </param>
+        public Dense(int units,
+            IActivation activation,
+            bool use_bias = true,
+            bool trainable = false,
+            IInitializer kernel_initializer = null,
+            IInitializer bias_initializer = null) : base(units,
+                activation,
+                use_bias: use_bias,
+                trainable: trainable,
+                kernel_initializer: kernel_initializer,
+                bias_initializer: bias_initializer)
+        {
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs
index 4c0a7cee..b4f7197b 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs
@@ -11,9 +11,9 @@ namespace Tensorflow.Operations.Initializers
private int? seed;
private TF_DataType dtype;
- public TruncatedNormal(float mean = 0.0f,
- float stddev = 1.0f,
- int? seed = null,
+ public TruncatedNormal(float mean = 0.0f,
+ float stddev = 1.0f,
+ int? seed = null,
TF_DataType dtype = TF_DataType.TF_FLOAT)
{
this.mean = mean;
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/MaxPoolFunction.cs b/src/TensorFlowNET.Core/Operations/NnOps/MaxPoolFunction.cs
new file mode 100644
index 00000000..5f15706e
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/NnOps/MaxPoolFunction.cs
@@ -0,0 +1,29 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations
+{
+    /// <summary>
+    /// <c>IPoolFunction</c> implementation that runs the max-pool op inside a
+    /// "MaxPool" name scope.
+    /// </summary>
+    public class MaxPoolFunction : Python, IPoolFunction
+    {
+        /// <summary>
+        /// Converts <paramref name="value"/> to a tensor and passes it, with the pooling
+        /// options, to <c>gen_nn_ops.max_pool</c>.
+        /// </summary>
+        /// <param name="value">Input value.</param>
+        /// <param name="ksize">Forwarded as <c>ksize</c>.</param>
+        /// <param name="strides">Forwarded as <c>strides</c>.</param>
+        /// <param name="padding">Forwarded as <c>padding</c>.</param>
+        /// <param name="data_format">Forwarded as <c>data_format</c>.</param>
+        /// <param name="name">Used for the name scope and forwarded as the op name.</param>
+        /// <returns>The tensor returned by <c>gen_nn_ops.max_pool</c>.</returns>
+        public Tensor Apply(Tensor value,
+            int[] ksize,
+            int[] strides,
+            string padding,
+            string data_format = "NHWC",
+            string name = null)
+        {
+            var name_scope = ops.name_scope(name, "MaxPool", new { value });
+
+            return with(name_scope, scope =>
+            {
+                var input = ops.convert_to_tensor(value, name: "input");
+
+                return gen_nn_ops.max_pool(
+                    input,
+                    ksize: ksize,
+                    strides: strides,
+                    padding: padding,
+                    data_format: data_format,
+                    name: name);
+            });
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
index 9dd853d9..4f8cd7ff 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
@@ -78,9 +78,35 @@ namespace Tensorflow.Operations
return _op.outputs;
}
- public static Tensor max_pool()
+        /// <summary>
+        /// Builds a "MaxPool" op through the op-def helper.
+        /// </summary>
+        /// <param name="input">Passed to the op as <c>input</c>.</param>
+        /// <param name="ksize">Passed to the op as <c>ksize</c>.</param>
+        /// <param name="strides">Passed to the op as <c>strides</c>.</param>
+        /// <param name="padding">Passed to the op as <c>padding</c>.</param>
+        /// <param name="data_format">Passed to the op as <c>data_format</c>.</param>
+        /// <param name="name">Optional operation name.</param>
+        /// <returns>The first output of the created op.</returns>
+        public static Tensor max_pool(Tensor input,
+            int[] ksize,
+            int[] strides,
+            string padding,
+            string data_format = "NHWC",
+            string name = null)
         {
- throw new NotImplementedException("");
+            // The anonymous object's member names double as the op's argument names.
+            var op_args = new { input, ksize, strides, padding, data_format };
+            var max_pool_op = _op_def_lib._apply_op_helper("MaxPool", name: name, args: op_args);
+
+            return max_pool_op.outputs[0];
+        }
+
+        /// <summary>
+        /// Builds a "TopKV2" op through the op-def helper.
+        /// </summary>
+        /// <param name="input">Passed to the op as <c>input</c>.</param>
+        /// <param name="k">Passed to the op as <c>k</c>.</param>
+        /// <param name="sorted">Passed to the op as <c>sorted</c>.</param>
+        /// <param name="name">Optional operation name.</param>
+        /// <returns>All outputs of the created op.</returns>
+        public static Tensor[] top_kv2(Tensor input, int k, bool sorted = true, string name = null)
+        {
+            // The anonymous object's member names double as the op's argument names.
+            var op_args = new { input, k, sorted };
+            var top_k_op = _op_def_lib._apply_op_helper("TopKV2", name: name, args: op_args);
+
+            return top_k_op.outputs;
         }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index de8cc9d5..9260d9fe 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -115,6 +115,12 @@ namespace Tensorflow
return _op.outputs[0];
}
+        /// <summary>
+        /// Builds a "Reshape" op whose target shape is given as an integer array.
+        /// </summary>
+        /// <param name="tensor">Passed to the op as <c>tensor</c>.</param>
+        /// <param name="shape">Passed to the op as <c>shape</c>.</param>
+        /// <param name="name">Optional operation name.</param>
+        /// <returns>The first output of the created op.</returns>
+        public static Tensor reshape(Tensor tensor, int[] shape, string name = null)
+        {
+            var op_args = new { tensor, shape };
+
+            return _op_def_lib._apply_op_helper("Reshape", name, op_args).outputs[0];
+        }
+
public static Tensor where()
{
throw new NotImplementedException("where");
diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
index 9d1e8788..1033f76c 100644
--- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
@@ -221,5 +221,20 @@ namespace Tensorflow
return _op.outputs[0];
}
+
+        /// <summary>
+        /// Returns the index with the largest value across dimensions of a tensor.
+        /// </summary>
+        /// <param name="input">Passed to the "ArgMax" op as <c>input</c>.</param>
+        /// <param name="dimension">Passed to the op as <c>dimension</c>.</param>
+        /// <param name="output_type">Passed to the op as <c>output_type</c>.</param>
+        /// <param name="name">Optional operation name.</param>
+        /// <returns>The first output of the created op.</returns>
+        public static Tensor arg_max(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null)
+        {
+            var op_args = new { input, dimension, output_type };
+
+            return _op_def_lib._apply_op_helper("ArgMax", name, op_args).outputs[0];
+        }
}
}
diff --git a/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs b/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs
index d62e52c1..cce3bbef 100644
--- a/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs
+++ b/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs
@@ -14,6 +14,7 @@ namespace TensorFlowNET.Examples.TextClassification
private int[] num_blocks;
private float learning_rate;
private IInitializer cnn_initializer;
+ private IInitializer fc_initializer;
private Tensor x;
private Tensor y;
private Tensor is_training;
@@ -30,6 +31,8 @@ namespace TensorFlowNET.Examples.TextClassification
num_blocks = new int[] { 2, 2, 2, 2 };
learning_rate = 0.001f;
cnn_initializer = tf.keras.initializers.he_normal();
+ fc_initializer = tf.truncated_normal_initializer(stddev: 0.05f);
+
x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
is_training = tf.placeholder(tf.boolean, new TensorShape(), name: "is_training");
@@ -46,6 +49,14 @@ namespace TensorFlowNET.Examples.TextClassification
Tensor conv0 = null;
Tensor conv1 = null;
+ Tensor conv2 = null;
+ Tensor conv3 = null;
+ Tensor conv4 = null;
+ Tensor h_flat = null;
+ Tensor fc1_out = null;
+ Tensor fc2_out = null;
+ Tensor logits = null;
+ Tensor predictions = null;
// First Convolution Layer
with(tf.variable_scope("conv-0"), delegate
@@ -62,7 +73,50 @@ namespace TensorFlowNET.Examples.TextClassification
with(tf.name_scope("conv-block-1"), delegate {
conv1 = conv_block(conv0, 1);
});
-
+
+ with(tf.name_scope("conv-block-2"), delegate {
+ conv2 = conv_block(conv1, 2);
+ });
+
+ with(tf.name_scope("conv-block-3"), delegate {
+ conv3 = conv_block(conv2, 3);
+ });
+
+ with(tf.name_scope("conv-block-4"), delegate
+ {
+ conv4 = conv_block(conv3, 4, max_pool: false);
+ });
+
+ // ============= k-max Pooling =============
+ with(tf.name_scope("k-max-pooling"), delegate
+ {
+ var h = tf.transpose(tf.squeeze(conv4, new int[] { -1 }), new int[] { 0, 2, 1 });
+ var top_k = tf.nn.top_k(h, k: 8, sorted: false)[0];
+ h_flat = tf.reshape(top_k, new int[] { -1, 512 * 8 });
+ });
+
+ // ============= Fully Connected Layers =============
+ with(tf.name_scope("fc-1"), scope =>
+ {
+ fc1_out = tf.layers.dense(h_flat, 2048, activation: tf.nn.relu, kernel_initializer: fc_initializer);
+ });
+
+ with(tf.name_scope("fc-2"), scope =>
+ {
+ fc2_out = tf.layers.dense(fc1_out, 2048, activation: tf.nn.relu, kernel_initializer: fc_initializer);
+ });
+
+ with(tf.name_scope("fc-3"), scope =>
+ {
+ logits = tf.layers.dense(fc2_out, num_class, activation: null, kernel_initializer: fc_initializer);
+ predictions = tf.argmax(logits, -1, output_type: tf.int32);
+ });
+
+ // ============= Loss and Accuracy =============
+ with(tf.name_scope("loss"), delegate
+ {
+ var y_one_hot = tf.one_hot(y, num_class);
+ });
}
private Tensor conv_block(Tensor input, int i, bool max_pool = true)