diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs
index 089dd8a5..9f989bc5 100644
--- a/src/TensorFlowNET.Core/APIs/tf.layers.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs
@@ -15,6 +15,8 @@
******************************************************************************/
using System.Collections.Generic;
+using System.Linq;
+using NumSharp;
using Tensorflow.Keras.Layers;
using Tensorflow.Operations.Activation;
using static Tensorflow.Binding;
@@ -144,6 +146,20 @@ namespace Tensorflow
return layer.apply(inputs);
}
+ /// <summary>
+ /// Densely-connected layer class. aka fully-connected
+ /// `outputs = activation(inputs * kernel + bias)`
+ /// </summary>
+ /// <param name="inputs"></param>
+ /// <param name="units">Python integer, dimensionality of the output space.</param>
+ /// <param name="activation"></param>
+ /// <param name="use_bias">Boolean, whether the layer uses a bias.</param>
+ /// <param name="bias_initializer"></param>
+ /// <param name="kernel_initializer"></param>
+ /// <param name="trainable"></param>
+ /// <param name="name"></param>
+ /// <returns></returns>
public Tensor dense(Tensor inputs,
int units,
IActivation activation = null,
@@ -160,7 +176,8 @@ namespace Tensorflow
var layer = new Dense(units, activation,
use_bias: use_bias,
bias_initializer: bias_initializer,
- kernel_initializer: kernel_initializer);
+ kernel_initializer: kernel_initializer,
+ trainable: trainable);
return layer.apply(inputs);
}
@@ -182,6 +199,7 @@ namespace Tensorflow
string name = null,
string data_format = "channels_last")
{
+ var input_shape = inputs.shape;
if (inputs.shape.Length == 0)
throw new ValueError($"Input 0 of layer flatten is incompatible with the layer: : expected min_ndim={1}, found ndim={0}. Full shape received: ()");
@@ -193,9 +211,24 @@ namespace Tensorflow
inputs = array_ops.transpose(inputs, premutation.ToArray());
}
- var ret = array_ops.reshape(inputs, new int[] {inputs.shape[0], -1});
- ret.set_shape(new int[] {inputs.shape[0], -1});
+ var ret = array_ops.reshape(inputs, compute_output_shape(input_shape));
+ //ret.set_shape(compute_output_shape(ret.shape));
return ret;
+
+ int[] compute_output_shape(int[] inputshape)
+ {
+ if (inputshape == null || inputshape.Length == 0)
+ inputshape = new int[] {1};
+
+ if (inputshape.Skip(1).All(d => d > 0))
+ {
+ int[] output_shape = new int[2];
+ output_shape[0] = inputshape[0];
+ output_shape[1] = inputshape.Skip(1).Aggregate(1, (acc, rhs) => acc * rhs); // product of all remaining dimensions
+ return output_shape;
+ }
+ else
+ return new int[] { inputshape[0], -1 }; // -1 == Binding.Unknown
+ }
}
}
}
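For context, a minimal usage sketch of the two layers touched above (graph-mode style, matching the existing unit tests; tensor names and values are illustrative only):

```csharp
using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;

var sess = tf.Session().as_default();

// dense: outputs = activation(inputs * kernel + bias); `trainable` is now forwarded to the layer.
var x = tf.placeholder(TF_DataType.TF_FLOAT, new TensorShape(Unknown, 4));
var dense = tf.layers.dense(x, units: 3, trainable: false);

// flatten: everything after the batch dimension is collapsed.
// Here the tail contains an unknown dim, so the static shape is (3, -1);
// once values are fed, the run result comes back shaped (3, 24).
var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, Unknown, 1, 2));
var flat = tf.layers.flatten(input);
var result = sess.run(flat, (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2)));
```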
diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index ec081cc4..985e3f73 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/
+using Tensorflow.Operations;
+
namespace Tensorflow
{
public partial class tensorflow
@@ -211,6 +213,36 @@ namespace Tensorflow
///
public Tensor _clip_by_value(Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = null)
=> gen_math_ops._clip_by_value(t, clip_value_min, clip_value_max);
+
+ /// <summary>
+ /// Clips tensor values to a specified min and max.
+ /// </summary>
+ /// <param name="t">
+ /// A Tensor.
+ /// </param>
+ /// <param name="clip_value_min">
+ /// A 0-D (scalar) Tensor, or a Tensor with the same shape
+ /// as t. The minimum value to clip by.
+ /// </param>
+ /// <param name="clip_value_max">
+ /// A 0-D (scalar) Tensor, or a Tensor with the same shape
+ /// as t. The maximum value to clip by.
+ /// </param>
+ /// <param name="name">
+ /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ClipByValue'.
+ /// </param>
+ /// <returns>
+ /// A clipped Tensor with the same shape as input 't'.
+ /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result.
+ /// </returns>
+ /// <remarks>
+ /// Given a tensor t, this operation returns a tensor of the same type and
+ /// shape as t with its values clipped to clip_value_min and clip_value_max.
+ /// Any values less than clip_value_min are set to clip_value_min. Any values
+ /// greater than clip_value_max are set to clip_value_max.
+ /// </remarks>
+ public Tensor clip_by_value(Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = "ClipByValue")
+ => gen_ops.clip_by_value(t, clip_value_min, clip_value_max, name);
public Tensor sub(Tensor a, Tensor b)
=> gen_math_ops.sub(a, b);
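A minimal sketch of the new tf.clip_by_value overload (the values are illustrative; min and max are passed as tensors, matching the signature above):

```csharp
using Tensorflow;
using static Tensorflow.Binding;

var t = tf.constant(new float[] { -2f, 0.5f, 7f });
var clipped = tf.clip_by_value(t, tf.constant(0f), tf.constant(1f));

using (var sess = tf.Session())
{
    var result = sess.run(clipped);  // [0, 0.5, 1]
    print(result);
}
```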
diff --git a/src/TensorFlowNET.Core/APIs/tf.tensor.cs b/src/TensorFlowNET.Core/APIs/tf.tensor.cs
index b553095e..2052de93 100644
--- a/src/TensorFlowNET.Core/APIs/tf.tensor.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.tensor.cs
@@ -18,8 +18,8 @@ namespace Tensorflow
{
public partial class tensorflow
{
- public Tensor convert_to_tensor(object value,
- string name = null) => ops.convert_to_tensor(value, name: name);
+ public Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid)
+ => ops.convert_to_tensor(value, dtype, name, preferred_dtype);
public Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tensor strides = null,
int begin_mask = 0,
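A short sketch of the widened convert_to_tensor signature; the dtype and preferred_dtype arguments mirror the Python API, and the values below are only for illustration:

```csharp
using Tensorflow;
using static Tensorflow.Binding;

// Explicit dtype takes precedence over what the value would infer.
var a = tf.convert_to_tensor(new float[] { 1f, 2f, 3f }, dtype: TF_DataType.TF_FLOAT);

// preferred_dtype acts as a hint rather than a hard requirement (following the Python semantics).
var b = tf.convert_to_tensor(42, preferred_dtype: TF_DataType.TF_INT64);
```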
diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs
index ff383642..def78327 100644
--- a/src/TensorFlowNET.Core/Binding.Util.cs
+++ b/src/TensorFlowNET.Core/Binding.Util.cs
@@ -21,6 +21,7 @@ using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.Linq;
+using NumSharp.Utilities;
namespace Tensorflow
{
@@ -29,9 +30,37 @@ namespace Tensorflow
///
public static partial class Binding
{
+ private static string _tostring(object obj)
+ {
+ switch (obj)
+ {
+ case NDArray nd:
+ return nd.ToString(false);
+ case Array arr:
+ if (arr.Rank != 1 || arr.GetType().GetElementType()?.IsArray == true)
+ arr = Arrays.Flatten(arr);
+ var objs = toObjectArray(arr);
+ return $"[{string.Join(", ", objs.Select(_tostring))}]";
+ default:
+ return obj?.ToString() ?? "null";
+ }
+
+ object[] toObjectArray(Array arr)
+ {
+ var len = arr.LongLength;
+ var ret = new object[len];
+ for (long i = 0; i < len; i++)
+ {
+ ret[i] = arr.GetValue(i);
+ }
+
+ return ret;
+ }
+ }
+
public static void print(object obj)
{
- Console.WriteLine(obj.ToString());
+ Console.WriteLine(_tostring(obj));
}
public static int len(object a)
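A small sketch of what the reworked print() produces; the exact NDArray text depends on NumSharp's formatting, so the comments show the shape of the output rather than exact strings:

```csharp
using NumSharp;
using static Tensorflow.Binding;

print(new[] { 1, 2, 3 });              // [1, 2, 3]
print(new[,] { { 1, 2 }, { 3, 4 } });  // multi-dimensional arrays are flattened: [1, 2, 3, 4]
print(np.arange(4));                   // NDArrays go through NDArray.ToString(false)
print("hello");                        // other objects fall back to ToString(): hello
```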
diff --git a/src/TensorFlowNET.Core/Binding.cs b/src/TensorFlowNET.Core/Binding.cs
index f443f2eb..e3136f83 100644
--- a/src/TensorFlowNET.Core/Binding.cs
+++ b/src/TensorFlowNET.Core/Binding.cs
@@ -7,5 +7,16 @@ namespace Tensorflow
public static partial class Binding
{
public static tensorflow tf { get; } = New();
+
+ /// <summary>
+ /// Alias for null, similar to Python's None.
+ /// For TensorShape dimensions, please use Unknown instead.
+ /// </summary>
+ public static readonly object None = null;
+
+ /// <summary>
+ /// Marks an unknown (None) dimension in a TensorShape; equal to -1.
+ /// </summary>
+ public static readonly int Unknown = -1;
}
}
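A brief sketch of how the new aliases are meant to be used when declaring shapes with unknown dimensions (the shape values are arbitrary):

```csharp
using Tensorflow;
using static Tensorflow.Binding;

// Unknown (-1) marks a dimension whose size is only known at run time, e.g. the batch size.
var shape = new TensorShape(Unknown, 28, 28, 1);
var images = tf.placeholder(TF_DataType.TF_FLOAT, shape);
```

Using a named constant instead of a bare -1 keeps call sites readable and mirrors Python's None-in-shape convention.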
diff --git a/src/TensorFlowNET.Core/Operations/Activation/gen_nn_ops.activations.cs b/src/TensorFlowNET.Core/Operations/Activation/gen_nn_ops.activations.cs
index 80e1c305..788adda4 100644
--- a/src/TensorFlowNET.Core/Operations/Activation/gen_nn_ops.activations.cs
+++ b/src/TensorFlowNET.Core/Operations/Activation/gen_nn_ops.activations.cs
@@ -14,20 +14,192 @@
limitations under the License.
******************************************************************************/
+using System;
+using static Tensorflow.Binding;
+
namespace Tensorflow.Operations.Activation
{
+ public class sigmoid : IActivation
+ {
+ public Tensor Activate(Tensor x, string name = null)
+ {
+ return tf.sigmoid(x);
+ }
+ }
+
+ public class tanh : IActivation
+ {
+ public Tensor Activate(Tensor x, string name = null)
+ {
+ return tf.tanh(x);
+ }
+ }
+
+ public class leakyrelu : IActivation
+ {
+ private readonly float _alpha;
+
+ public leakyrelu(float alpha = 0.3f)
+ {
+ _alpha = alpha;
+ }
+
+ public Tensor Activate(Tensor x, string name = null)
+ {
+ return nn_ops.leaky_relu(x, _alpha);
+ }
+ }
+
+ public class elu : IActivation
+ {
+ private readonly float _alpha;
+
+ public elu(float alpha = 0.1f)
+ {
+ _alpha = alpha;
+ }
+
+ public Tensor Activate(Tensor x, string name = null)
+ {
+ var res = gen_ops.elu(x);
+ // gen_ops.elu already applies alpha == 1, so only rescale the negative part for other alphas
+ if (Math.Abs(_alpha - 1f) < 0.00001f)
+ {
+ return res;
+ }
+
+ return array_ops.@where(x > 0, res, _alpha * res);
+ }
+ }
+
+ public class softmax : IActivation
+ {
+ private readonly int _axis;
+
+ /// <summary>Initializes a new instance of the <see cref="softmax"/> class.</summary>
+ public softmax(int axis = -1)
+ {
+ _axis = axis;
+ }
+
+ public Tensor Activate(Tensor x, string name = null)
+ {
+ return nn_ops.softmax(x, _axis);
+ }
+ }
+
+ public class softplus : IActivation
+ {
+ public Tensor Activate(Tensor x, string name = null)
+ {
+ return gen_ops.softplus(x);
+ }
+ }
+
+ public class softsign : IActivation
+ {
+ public Tensor Activate(Tensor x, string name = null)
+ {
+ return gen_ops.softsign(x);
+ }
+ }
+
+ public class linear : IActivation
+ {
+ public Tensor Activate(Tensor x, string name = null)
+ {
+ return x;
+ }
+ }
+
+
+ public class exponential : IActivation
+ {
+ public Tensor Activate(Tensor x, string name = null)
+ {
+ return tf.exp(x, name: name);
+ }
+ }
+
+
public class relu : IActivation
{
- public Tensor Activate(Tensor features, string name = null)
+ private readonly float _threshold;
+ private readonly float _alpha;
+ private readonly float? _maxValue;
+
+ public relu(float threshold = 0f, float alpha = 0f, float? max_value = null)
+ {
+ _threshold = threshold;
+ _alpha = alpha;
+ _maxValue = max_value;
+ }
+
+ public Tensor Activate(Tensor x, string name = null)
{
- OpDefLibrary _op_def_lib = new OpDefLibrary();
+ //based on keras/backend.py
+ if (Math.Abs(_alpha) > 0.000001f)
+ {
+ if (!_maxValue.HasValue && Math.Abs(_threshold) < 0.0001)
+ {
+ return nn_ops.leaky_relu(x, _alpha);
+ }
+ }
+
+ Tensor negative_part;
+ if (Math.Abs(_threshold) > 0.000001f)
+ {
+ negative_part = gen_ops.relu(-x + _threshold);
+ }
+ else
+ {
+ negative_part = gen_ops.relu(-x);
+ }
+
+ if (Math.Abs(_threshold) > 0.000001f)
+ {
+ // keep only values above the threshold
+ x = x * math_ops.cast(tf.greater(x, _threshold), TF_DataType.TF_FLOAT);
+ }
+ else if (_maxValue.HasValue && Math.Abs(_maxValue.Value - 6f) < 0.0001f)
+ {
+ x = gen_ops.relu6(x);
+ }
+ else
+ {
+ x = gen_ops.relu(x);
+ }
+
+ bool clip_max = _maxValue.HasValue;
+ if (clip_max)
+ {
+ Tensor maxval = constant_op.constant(_maxValue, x.dtype.as_base_dtype());
+ var zero = constant_op.constant(0.0f, x.dtype.as_base_dtype());
+ x = gen_ops.clip_by_value(x, zero, maxval);
+ }
- var _op = _op_def_lib._apply_op_helper("Relu", name: name, args: new
+ if (Math.Abs(_alpha) > 0.00001)
{
- features
- });
+ var a = constant_op.constant(_alpha, x.dtype.as_base_dtype());
+ x -= a * negative_part;
+ }
- return _op.outputs[0];
+ return x;
+ }
+ }
+
+ public class selu : IActivation
+ {
+ public Tensor Activate(Tensor x, string name = null)
+ {
+ const float alpha = 1.6732632423543772848170429916717f;
+ const float scale = 1.0507009873554804934193349852946f;
+ return scale * new elu(alpha).Activate(x, name);
+ }
+ }
+
+ public class hard_sigmoid : IActivation
+ {
+ public Tensor Activate(Tensor x, string name = null)
+ {
+ x = (0.2f * x) + 0.5f;
+ var zero = tf.convert_to_tensor(0.0f, x.dtype.as_base_dtype());
+ var one = tf.convert_to_tensor(1.0f, x.dtype.as_base_dtype());
+ return tf.clip_by_value(x, zero, one);
}
}
-}
+}
\ No newline at end of file
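A sketch of how the new activation classes might be wired up, either applied directly or handed to a layer via IActivation; the constants and tensor values are illustrative only:

```csharp
using Tensorflow;
using Tensorflow.Operations.Activation;
using static Tensorflow.Binding;

var x = tf.constant(new float[] { -3f, -0.5f, 0f, 2f, 8f });

var capped = new relu(alpha: 0f, max_value: 6f).Activate(x);  // standard ReLU clipped at 6
var leaky = new leakyrelu(alpha: 0.2f).Activate(x);           // leaky slope on the negative side
var scaled = new selu().Activate(x);                          // SELU constants are baked in

// Activations also plug into layers that accept IActivation.
var logits = tf.placeholder(TF_DataType.TF_FLOAT, new TensorShape(Unknown, 10));
var probs = tf.layers.dense(logits, units: 10, activation: new softmax());
```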
diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
index 1fc95927..853255cb 100644
--- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs
+++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
@@ -3,6 +3,7 @@ using System;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
+using NumSharp.Utilities;
namespace Tensorflow
{
@@ -32,7 +33,23 @@ namespace Tensorflow
/// <summary>
/// Returns the size this shape represents.
/// </summary>
- public int size => shape.Size;
+ public int size
+ {
+ get
+ {
+ var dims = shape.Dimensions;
+ var computed = 1;
+ for (int i = 0; i < dims.Length; i++)
+ {
+ var val = dims[i];
+ if (val <= 0)
+ continue;
+ computed *= val;
+ }
+
+ return computed;
+ }
+ }
public TensorShape(TensorShapeProto proto)
{
@@ -59,12 +76,30 @@ namespace Tensorflow
switch (dims.Length)
{
case 0: shape = new Shape(new int[0]); break;
- case 1: shape = Shape.Vector((int) dims[0]); break;
+ case 1: shape = Shape.Vector((int)dims[0]); break;
case 2: shape = Shape.Matrix(dims[0], dims[1]); break;
default: shape = new Shape(dims); break;
}
}
+ public TensorShape(int[][] dims)
+ {
+ if(dims.Length == 1)
+ {
+ switch (dims[0].Length)
+ {
+ case 0: shape = new Shape(new int[0]); break;
+ case 1: shape = Shape.Vector((int)dims[0][0]); break;
+ case 2: shape = Shape.Matrix(dims[0][0], dims[0][1]); break;
+ default: shape = new Shape(dims[0]); break;
+ }
+ }
+ else
+ {
+ throw new NotImplementedException("TensorShape int[][] dims");
+ }
+ }
+
///
///
///
@@ -188,6 +223,11 @@ namespace Tensorflow
public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6);
-
+
+ public static explicit operator (int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 7 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6]) : (0, 0, 0, 0, 0, 0, 0);
+ public static implicit operator TensorShape((int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7);
+
+ public static explicit operator (int, int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 8 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6], shape.dims[7]) : (0, 0, 0, 0, 0, 0, 0, 0);
+ public static implicit operator TensorShape((int, int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8);
}
}
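A quick sketch of the extended TensorShape behaviour: size now ignores unknown (-1) dimensions when computing the element count, and 7- and 8-element tuples convert implicitly (values are illustrative):

```csharp
using Tensorflow;
using static Tensorflow.Binding;

var shape = new TensorShape(Unknown, 4, 3);
print(shape.size);  // 12 - the unknown batch dimension is skipped

TensorShape seven = (1, 2, 3, 4, 5, 6, 7);           // new 7-tuple implicit conversion
TensorShape eight = (1, 2, 3, 4, 5, 6, 7, Unknown);  // new 8-tuple implicit conversion
```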
diff --git a/test/TensorFlowNET.UnitTest/TensorShapeTest.cs b/test/TensorFlowNET.UnitTest/TensorShapeTest.cs
new file mode 100644
index 00000000..efa7def3
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/TensorShapeTest.cs
@@ -0,0 +1,60 @@
+using System;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using static Tensorflow.Binding;
+
+namespace TensorFlowNET.UnitTest
+{
+ [TestClass]
+ public class TensorShapeTest
+ {
+ [TestMethod]
+ public void Case1()
+ {
+ int a = 2;
+ int b = 3;
+ var dims = new [] { Unknown, a, b};
+ new TensorShape(dims).GetPrivate("shape").Should().BeShaped(-1, 2, 3);
+ }
+
+ [TestMethod]
+ public void Case2()
+ {
+ int a = 2;
+ int b = 3;
+ var dims = new[] { Unknown, a, b};
+ new TensorShape(new [] {dims}).GetPrivate("shape").Should().BeShaped(-1, 2, 3);
+ }
+
+ [TestMethod]
+ public void Case3()
+ {
+ int a = 2;
+ int b = Unknown;
+ var dims = new [] { Unknown, a, b};
+ new TensorShape(new [] {dims}).GetPrivate("shape").Should().BeShaped(-1, 2, -1);
+ }
+
+ [TestMethod]
+ public void Case4()
+ {
+ TensorShape shape = (Unknown, Unknown);
+ shape.GetPrivate("shape").Should().BeShaped(-1, -1);
+ }
+
+ [TestMethod]
+ public void Case5()
+ {
+ TensorShape shape = (1, Unknown, 3);
+ shape.GetPrivate("shape").Should().BeShaped(1, -1, 3);
+ }
+
+ [TestMethod]
+ public void Case6()
+ {
+ TensorShape shape = (Unknown, 1, 2, 3, Unknown);
+ shape.GetPrivate("shape").Should().BeShaped(-1, 1, 2, 3, -1);
+ }
+ }
+}
\ No newline at end of file
diff --git a/test/TensorFlowNET.UnitTest/layers_test/flatten.cs b/test/TensorFlowNET.UnitTest/layers_test/flatten.cs
index d533f128..fa8ec792 100644
--- a/test/TensorFlowNET.UnitTest/layers_test/flatten.cs
+++ b/test/TensorFlowNET.UnitTest/layers_test/flatten.cs
@@ -36,5 +36,23 @@ namespace TensorFlowNET.UnitTest.layers_test
var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape());
new Action(() => sess.run(tf.layers.flatten(input), (input, NDArray.Scalar(6)))).Should().Throw();
}
+
+ [TestMethod]
+ public void Case4()
+ {
+ var sess = tf.Session().as_default();
+
+ var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, Unknown, 1, 2));
+ sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
+ }
+
+ [TestMethod]
+ public void Case5()
+ {
+ var sess = tf.Session().as_default();
+
+ var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(Unknown, 4, 3, 1, 2));
+ sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
+ }
}
}
\ No newline at end of file