diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index c63fe7fe..b10e1d77 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -30,5 +30,13 @@ namespace Tensorflow /// public static Tensor transpose(Tensor a, int[] perm = null, string name = "transpose", bool conjugate = false) => array_ops.transpose(a, perm, name, conjugate); + + public static Tensor squeeze(Tensor input, int[] axis = null, string name = null, int squeeze_dims = -1) + => gen_array_ops.squeeze(input, axis, name); + + public static Tensor one_hot(Tensor indices, int depth) + { + throw new NotImplementedException("one_hot"); + } } } diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs index fb5421c3..c6455341 100644 --- a/src/TensorFlowNET.Core/APIs/tf.init.cs +++ b/src/TensorFlowNET.Core/APIs/tf.init.cs @@ -27,5 +27,13 @@ namespace Tensorflow default_name, values, auxiliary_name_scope); + + public static IInitializer truncated_normal_initializer(float mean = 0.0f, + float stddev = 1.0f, + int? seed = null, + TF_DataType dtype = TF_DataType.DtInvalid) => new TruncatedNormal(mean: mean, + stddev: stddev, + seed: seed, + dtype: dtype); } } diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs index ba543953..faf0d089 100644 --- a/src/TensorFlowNET.Core/APIs/tf.layers.cs +++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs @@ -126,6 +126,26 @@ namespace Tensorflow return layer.apply(inputs); } + + public static Tensor dense(Tensor inputs, + int units, + IActivation activation = null, + bool use_bias = true, + IInitializer kernel_initializer = null, + IInitializer bias_initializer = null, + bool trainable = true, + string name = null, + bool? reuse = null) + { + if (bias_initializer == null) + bias_initializer = tf.zeros_initializer; + + var layer = new Dense(units, activation, + use_bias: use_bias, + kernel_initializer: kernel_initializer); + + return layer.apply(inputs); + } } } } diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 9226ce63..ad548864 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -6,21 +6,26 @@ namespace Tensorflow { public static partial class tf { - public static Tensor add(Tensor a, Tensor b) => gen_math_ops.add(a, b); + public static Tensor add(Tensor a, Tensor b) + => gen_math_ops.add(a, b); - public static Tensor sub(Tensor a, Tensor b) => gen_math_ops.sub(a, b); + public static Tensor sub(Tensor a, Tensor b) + => gen_math_ops.sub(a, b); - public static Tensor sqrt(Tensor a, string name = null) => gen_math_ops.sqrt(a, name); + public static Tensor sqrt(Tensor a, string name = null) + => gen_math_ops.sqrt(a, name); public static Tensor subtract(Tensor x, T[] y, string name = null) where T : struct => gen_math_ops.sub(x, ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"), name); - public static Tensor multiply(Tensor x, Tensor y) => gen_math_ops.mul(x, y); + public static Tensor multiply(Tensor x, Tensor y) + => gen_math_ops.mul(x, y); public static Tensor divide(Tensor x, T[] y, string name = null) where T : struct => x / ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"); - public static Tensor pow(T1 x, T2 y) => gen_math_ops.pow(x, y); + public static Tensor pow(T1 x, T2 y) + => gen_math_ops.pow(x, y); /// /// Computes the sum of elements across dimensions of a tensor. @@ -28,9 +33,13 @@ namespace Tensorflow /// /// /// - public static Tensor reduce_sum(Tensor input, int[] axis = null) => math_ops.reduce_sum(input); + public static Tensor reduce_sum(Tensor input, int[] axis = null) + => math_ops.reduce_sum(input); public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) => math_ops.cast(x, dtype, name); + + public static Tensor argmax(Tensor input, int axis = -1, string name = null, int? dimension = null, TF_DataType output_type = TF_DataType.TF_INT64) + => gen_math_ops.arg_max(input, axis, name: name, output_type: output_type); } } diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index 5a940afe..87b646d7 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -42,7 +42,10 @@ namespace Tensorflow is_training: is_training, name: name); - public static Tensor max_pool() => gen_nn_ops.max_pool(); + public static IPoolFunction max_pool => new MaxPoolFunction(); + + public static Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null) + => gen_nn_ops.top_kv2(input, k: k, sorted: sorted, name: name); } } } diff --git a/src/TensorFlowNET.Core/APIs/tf.reshape.cs b/src/TensorFlowNET.Core/APIs/tf.reshape.cs index 41861968..30098581 100644 --- a/src/TensorFlowNET.Core/APIs/tf.reshape.cs +++ b/src/TensorFlowNET.Core/APIs/tf.reshape.cs @@ -10,5 +10,8 @@ namespace Tensorflow Tensor shape, string name = null) => gen_array_ops.reshape(tensor, shape, name); + public static Tensor reshape(Tensor tensor, + int[] shape, + string name = null) => gen_array_ops.reshape(tensor, shape, name); } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs index bfd96d6a..d84c31e1 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs @@ -10,16 +10,19 @@ namespace Tensorflow.Keras.Engine public class InputSpec { public int ndim; + public int? min_ndim; Dictionary axes; - public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, + public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, int? ndim = null, + int? min_ndim = null, Dictionary axes = null) { this.ndim = ndim.Value; if (axes == null) axes = new Dictionary(); this.axes = axes; + this.min_ndim = min_ndim; } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs index 7e722ff2..12df6b4d 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs @@ -122,7 +122,7 @@ namespace Tensorflow.Keras.Engine protected virtual void build(TensorShape input_shape) { - throw new NotImplementedException("Layer.build"); + built = true; } protected virtual RefVariable add_weight(string name, diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs new file mode 100644 index 00000000..7bf63413 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Engine; +using Tensorflow.Operations.Activation; + +namespace Tensorflow.Keras.Layers +{ + public class Dense : Tensorflow.Layers.Layer + { + protected int uints; + protected IActivation activation; + protected bool use_bias; + protected IInitializer kernel_initializer; + protected IInitializer bias_initializer; + + public Dense(int units, + IActivation activation, + bool use_bias = true, + bool trainable = false, + IInitializer kernel_initializer = null, + IInitializer bias_initializer = null) : base(trainable: trainable) + { + this.uints = units; + this.activation = activation; + this.use_bias = use_bias; + this.kernel_initializer = kernel_initializer; + this.bias_initializer = bias_initializer; + this.supports_masking = true; + this.input_spec = new InputSpec(min_ndim: 2); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/IPoolFunction.cs b/src/TensorFlowNET.Core/Keras/Layers/IPoolFunction.cs new file mode 100644 index 00000000..07544f10 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/IPoolFunction.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public interface IPoolFunction + { + Tensor Apply(Tensor value, + int[] ksize, + int[] strides, + string padding, + string data_format = "NHWC", + string name = null); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs index 1bdb769b..69c4d65c 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs @@ -8,14 +8,14 @@ namespace Tensorflow.Keras.Layers { public class Pooling2D : Tensorflow.Layers.Layer { - private Func pool_function; + private IPoolFunction pool_function; private int[] pool_size; private int[] strides; private string padding; private string data_format; private InputSpec input_spec; - public Pooling2D(Func pool_function, + public Pooling2D(IPoolFunction pool_function, int[] pool_size, int[] strides, string padding = "valid", @@ -29,5 +29,29 @@ namespace Tensorflow.Keras.Layers this.data_format = conv_utils.normalize_data_format(data_format); this.input_spec = new InputSpec(ndim: 4); } + + protected override Tensor call(Tensor inputs, Tensor training = null) + { + int[] pool_shape; + if (data_format == "channels_last") + { + pool_shape = new int[] { 1, pool_size[0], pool_size[1], 1 }; + strides = new int[] { 1, strides[0], strides[1], 1 }; + } + else + { + pool_shape = new int[] { 1, 1, pool_size[0], pool_size[1] }; + strides = new int[] { 1, 1, strides[0], strides[1] }; + } + + var outputs = pool_function.Apply( + inputs, + ksize: pool_shape, + strides: strides, + padding: padding.ToUpper(), + data_format: conv_utils.convert_data_format(data_format, 4)); + + return outputs; + } } } diff --git a/src/TensorFlowNET.Core/Layers/Dense.cs b/src/TensorFlowNET.Core/Layers/Dense.cs new file mode 100644 index 00000000..e2868877 --- /dev/null +++ b/src/TensorFlowNET.Core/Layers/Dense.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Operations.Activation; + +namespace Tensorflow.Layers +{ + public class Dense : Keras.Layers.Dense + { + public Dense(int units, + IActivation activation, + bool use_bias = true, + bool trainable = false, + IInitializer kernel_initializer = null) : base(units, + activation, + use_bias: use_bias, + trainable: trainable, + kernel_initializer: kernel_initializer) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs index 4c0a7cee..b4f7197b 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs @@ -11,9 +11,9 @@ namespace Tensorflow.Operations.Initializers private int? seed; private TF_DataType dtype; - public TruncatedNormal(float mean = 0.0f, - float stddev = 1.0f, - int? seed = null, + public TruncatedNormal(float mean = 0.0f, + float stddev = 1.0f, + int? seed = null, TF_DataType dtype = TF_DataType.TF_FLOAT) { this.mean = mean; diff --git a/src/TensorFlowNET.Core/Operations/NnOps/MaxPoolFunction.cs b/src/TensorFlowNET.Core/Operations/NnOps/MaxPoolFunction.cs new file mode 100644 index 00000000..5f15706e --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/MaxPoolFunction.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations +{ + public class MaxPoolFunction : Python, IPoolFunction + { + public Tensor Apply(Tensor value, + int[] ksize, + int[] strides, + string padding, + string data_format = "NHWC", + string name = null) + { + return with(ops.name_scope(name, "MaxPool", new { value }), scope => { + + value = ops.convert_to_tensor(value, name: "input"); + return gen_nn_ops.max_pool( + value, + ksize: ksize, + strides: strides, + padding: padding, + data_format: data_format, + name: name); + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index 9dd853d9..4f8cd7ff 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -78,9 +78,35 @@ namespace Tensorflow.Operations return _op.outputs; } - public static Tensor max_pool() + public static Tensor max_pool(Tensor input, + int[] ksize, + int[] strides, + string padding, + string data_format = "NHWC", + string name = null) { - throw new NotImplementedException(""); + var _op = _op_def_lib._apply_op_helper("MaxPool", name: name, args: new + { + input, + ksize, + strides, + padding, + data_format, + }); + + return _op.outputs[0]; + } + + public static Tensor[] top_kv2(Tensor input, int k, bool sorted = true, string name = null) + { + var _op = _op_def_lib._apply_op_helper("TopKV2", name: name, args: new + { + input, + k, + sorted + }); + + return _op.outputs; } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index de8cc9d5..9260d9fe 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -115,6 +115,12 @@ namespace Tensorflow return _op.outputs[0]; } + public static Tensor reshape(Tensor tensor, int[] shape, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Reshape", name, new { tensor, shape }); + return _op.outputs[0]; + } + public static Tensor where() { throw new NotImplementedException("where"); diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 9d1e8788..1033f76c 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -221,5 +221,20 @@ namespace Tensorflow return _op.outputs[0]; } + + /// + /// Returns the index with the largest value across dimensions of a tensor. + /// + /// + /// + /// + /// + /// + public static Tensor arg_max(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) + { + var _op = _op_def_lib._apply_op_helper("ArgMax", name, new { input, dimension, output_type }); + + return _op.outputs[0]; + } } } diff --git a/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs b/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs index d62e52c1..cce3bbef 100644 --- a/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs +++ b/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs @@ -14,6 +14,7 @@ namespace TensorFlowNET.Examples.TextClassification private int[] num_blocks; private float learning_rate; private IInitializer cnn_initializer; + private IInitializer fc_initializer; private Tensor x; private Tensor y; private Tensor is_training; @@ -30,6 +31,8 @@ namespace TensorFlowNET.Examples.TextClassification num_blocks = new int[] { 2, 2, 2, 2 }; learning_rate = 0.001f; cnn_initializer = tf.keras.initializers.he_normal(); + fc_initializer = tf.truncated_normal_initializer(stddev: 0.05f); + x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x"); y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y"); is_training = tf.placeholder(tf.boolean, new TensorShape(), name: "is_training"); @@ -46,6 +49,14 @@ namespace TensorFlowNET.Examples.TextClassification Tensor conv0 = null; Tensor conv1 = null; + Tensor conv2 = null; + Tensor conv3 = null; + Tensor conv4 = null; + Tensor h_flat = null; + Tensor fc1_out = null; + Tensor fc2_out = null; + Tensor logits = null; + Tensor predictions = null; // First Convolution Layer with(tf.variable_scope("conv-0"), delegate @@ -62,7 +73,50 @@ namespace TensorFlowNET.Examples.TextClassification with(tf.name_scope("conv-block-1"), delegate { conv1 = conv_block(conv0, 1); }); - + + with(tf.name_scope("conv-block-2"), delegate { + conv2 = conv_block(conv1, 2); + }); + + with(tf.name_scope("conv-block-3"), delegate { + conv3 = conv_block(conv2, 3); + }); + + with(tf.name_scope("conv-block-4"), delegate + { + conv4 = conv_block(conv3, 4, max_pool: false); + }); + + // ============= k-max Pooling ============= + with(tf.name_scope("k-max-pooling"), delegate + { + var h = tf.transpose(tf.squeeze(conv4, new int[] { -1 }), new int[] { 0, 2, 1 }); + var top_k = tf.nn.top_k(h, k: 8, sorted: false)[0]; + h_flat = tf.reshape(top_k, new int[] { -1, 512 * 8 }); + }); + + // ============= Fully Connected Layers ============= + with(tf.name_scope("fc-1"), scope => + { + fc1_out = tf.layers.dense(h_flat, 2048, activation: tf.nn.relu, kernel_initializer: fc_initializer); + }); + + with(tf.name_scope("fc-2"), scope => + { + fc2_out = tf.layers.dense(fc1_out, 2048, activation: tf.nn.relu, kernel_initializer: fc_initializer); + }); + + with(tf.name_scope("fc-3"), scope => + { + logits = tf.layers.dense(fc2_out, num_class, activation: null, kernel_initializer: fc_initializer); + predictions = tf.argmax(logits, -1, output_type: tf.int32); + }); + + // ============= Loss and Accuracy ============= + with(tf.name_scope("loss"), delegate + { + var y_one_hot = tf.one_hot(y, num_class); + }); } private Tensor conv_block(Tensor input, int i, bool max_pool = true)