diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs
index 089dd8a5..9f989bc5 100644
--- a/src/TensorFlowNET.Core/APIs/tf.layers.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs
@@ -15,6 +15,8 @@
 ******************************************************************************/
 
 using System.Collections.Generic;
+using System.Linq;
+using NumSharp;
 using Tensorflow.Keras.Layers;
 using Tensorflow.Operations.Activation;
 using static Tensorflow.Binding;
@@ -144,6 +146,20 @@ namespace Tensorflow
                 return layer.apply(inputs);
             }
 
+            /// <summary>
+            /// Densely-connected layer, aka fully-connected:
+            /// `outputs = activation(inputs * kernel + bias)`
+            /// </summary>
+            /// <param name="inputs"></param>
+            /// <param name="units">Python integer, dimensionality of the output space.</param>
+            /// <param name="activation"></param>
+            /// <param name="use_bias">Boolean, whether the layer uses a bias.</param>
+            /// <param name="bias_initializer"></param>
+            /// <param name="kernel_initializer"></param>
+            /// <param name="trainable"></param>
+            /// <param name="name"></param>
+            /// <param name="reuse"></param>
+            /// <returns></returns>
             public Tensor dense(Tensor inputs,
                 int units,
                 IActivation activation = null,
@@ -160,7 +176,8 @@ namespace Tensorflow
                 var layer = new Dense(units, activation,
                     use_bias: use_bias,
                     bias_initializer: bias_initializer,
-                    kernel_initializer: kernel_initializer);
+                    kernel_initializer: kernel_initializer,
+                    trainable: trainable);
 
                 return layer.apply(inputs);
             }
@@ -182,6 +199,7 @@ namespace Tensorflow
                 string name = null,
                 string data_format = "channels_last")
             {
+                var input_shape = inputs.shape;
                 if (inputs.shape.Length == 0)
                     throw new ValueError($"Input 0 of layer flatten is incompatible with the layer: : expected min_ndim={1}, found ndim={0}. Full shape received: ()");
 
@@ -193,9 +211,24 @@ namespace Tensorflow
                     inputs = array_ops.transpose(inputs, premutation.ToArray());
                 }
 
-                var ret = array_ops.reshape(inputs, new int[] {inputs.shape[0], -1});
-                ret.set_shape(new int[] {inputs.shape[0], -1});
+                var ret = array_ops.reshape(inputs, compute_output_shape(input_shape));
+                //ret.set_shape(compute_output_shape(ret.shape));
                 return ret;
+
+                int[] compute_output_shape(int[] inputshape)
+                {
+                    if (inputshape == null || inputshape.Length == 0)
+                        inputshape = new int[] {1};
+
+                    if (inputshape.Skip(1).All(d => d > 0))
+                    {
+                        int[] output_shape = new int[2];
+                        output_shape[0] = inputshape[0];
+                        output_shape[1] = inputshape.Skip(1).Aggregate(1, (acc, rhs) => acc * rhs); // product of all non-batch dimensions
+                        return output_shape;
+                    } else
+                        return new int[] {inputshape[0], -1}; // -1 == Binding.Unknown
+                }
             }
         }
     }
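Usage sketch (not part of the diff, written against the hunks above): with compute_output_shape, flatten reshapes to a fully-known (batch, product-of-the-rest) target whenever every non-batch dimension is known, and only falls back to (batch, -1) when one of them is unknown.

    var images = tf.placeholder(TF_DataType.TF_FLOAT, new TensorShape(Unknown, 28, 28, 1));
    var flat = tf.layers.flatten(images);     // reshape target (batch, 784), since 28 * 28 * 1 = 784
    var mixed = tf.placeholder(TF_DataType.TF_FLOAT, new TensorShape(Unknown, 28, Unknown, 1));
    var flat2 = tf.layers.flatten(mixed);     // reshape target (batch, -1): a non-batch dimension is unknown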
diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index ec081cc4..985e3f73 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -14,6 +14,8 @@
    limitations under the License.
 ******************************************************************************/
 
+using Tensorflow.Operations;
+
 namespace Tensorflow
 {
     public partial class tensorflow
@@ -211,6 +213,36 @@ namespace Tensorflow
         /// </param>
         public Tensor _clip_by_value(Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = null)
             => gen_math_ops._clip_by_value(t, clip_value_min, clip_value_max);
+
+        /// <summary>
+        /// Clips tensor values to a specified min and max.
+        /// </summary>
+        /// <param name="t">
+        /// A Tensor.
+        /// </param>
+        /// <param name="clip_value_min">
+        /// A 0-D (scalar) Tensor, or a Tensor with the same shape
+        /// as t. The minimum value to clip by.
+        /// </param>
+        /// <param name="clip_value_max">
+        /// A 0-D (scalar) Tensor, or a Tensor with the same shape
+        /// as t. The maximum value to clip by.
+        /// </param>
+        /// <param name="name">
+        /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ClipByValue'.
+        /// </param>
+        /// <returns>
+        /// A clipped Tensor with the same shape as input 't'.
+        /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result.
+        /// </returns>
+        /// <remarks>
+        /// Given a tensor t, this operation returns a tensor of the same type and
+        /// shape as t with its values clipped to clip_value_min and clip_value_max.
+        /// Any values less than clip_value_min are set to clip_value_min. Any values
+        /// greater than clip_value_max are set to clip_value_max.
+        /// </remarks>
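Usage sketch (not part of the diff): the new public tf.clip_by_value takes Tensor bounds, so scalar limits are wrapped with tf.constant.

    var t = tf.constant(new float[] { -2f, 0.5f, 3f });
    var clipped = tf.clip_by_value(t, tf.constant(0f), tf.constant(1f));   // evaluates to [0, 0.5, 1]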
+        public Tensor clip_by_value(Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = "ClipByValue")
+            => gen_ops.clip_by_value(t, clip_value_min, clip_value_max, name);
 
         public Tensor sub(Tensor a, Tensor b)
             => gen_math_ops.sub(a, b);
diff --git a/src/TensorFlowNET.Core/APIs/tf.tensor.cs b/src/TensorFlowNET.Core/APIs/tf.tensor.cs
index b553095e..2052de93 100644
--- a/src/TensorFlowNET.Core/APIs/tf.tensor.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.tensor.cs
@@ -18,8 +18,8 @@ namespace Tensorflow
 {
     public partial class tensorflow
     {
-        public Tensor convert_to_tensor(object value,
-            string name = null) => ops.convert_to_tensor(value, name: name);
+        public Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid)
+            => ops.convert_to_tensor(value, dtype, name, preferred_dtype);
 
         public Tensor strided_slice(Tensor input, Tensor begin, Tensor end,
             Tensor strides = null,
             int begin_mask = 0,
diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs
index ff383642..def78327 100644
--- a/src/TensorFlowNET.Core/Binding.Util.cs
+++ b/src/TensorFlowNET.Core/Binding.Util.cs
@@ -21,6 +21,7 @@ using System.Collections.Generic;
 using System.ComponentModel;
 using System.Diagnostics;
 using System.Linq;
+using NumSharp.Utilities;
 
 namespace Tensorflow
 {
@@ -29,9 +30,37 @@ namespace Tensorflow
     /// </summary>
     public static partial class Binding
     {
+        private static string _tostring(object obj)
+        {
+            switch (obj)
+            {
+                case NDArray nd:
+                    return nd.ToString(false);
+                case Array arr:
+                    if (arr.Rank != 1 || arr.GetType().GetElementType()?.IsArray == true)
+                        arr = Arrays.Flatten(arr);
+                    var objs = toObjectArray(arr);
+                    return $"[{string.Join(", ", objs.Select(_tostring))}]";
+                default:
+                    return obj?.ToString() ?? "null";
+            }
+
+            object[] toObjectArray(Array arr)
+            {
+                var len = arr.LongLength;
+                var ret = new object[len];
+                for (long i = 0; i < len; i++)
+                {
+                    ret[i] = arr.GetValue(i);
+                }
+
+                return ret;
+            }
+        }
+
         public static void print(object obj)
         {
-            Console.WriteLine(obj.ToString());
+            Console.WriteLine(_tostring(obj));
         }
 
         public static int len(object a)
diff --git a/src/TensorFlowNET.Core/Binding.cs b/src/TensorFlowNET.Core/Binding.cs
index f443f2eb..e3136f83 100644
--- a/src/TensorFlowNET.Core/Binding.cs
+++ b/src/TensorFlowNET.Core/Binding.cs
@@ -7,5 +7,16 @@ namespace Tensorflow
     public static partial class Binding
     {
         public static tensorflow tf { get; } = New();
+
+        /// <summary>
+        /// Alias to null, similar to Python's None.
+        /// For an unknown TensorShape dimension, use Unknown instead.
+        /// </summary>
+        public static readonly object None = null;
+
+        /// <summary>
+        /// Marks an unknown (None) dimension in a TensorShape.
+        /// </summary>
+        public static readonly int Unknown = -1;
     }
 }
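Sketch of the new helpers in action (not part of the diff): print renders arrays element by element instead of falling back to Array.ToString(), and None/Unknown give Python-flavoured names to null and the -1 "unknown dimension" marker.

    print(new[] { 1, 2, 3 });               // "[1, 2, 3]" rather than "System.Int32[]"
    print(new[,] { { 1, 2 }, { 3, 4 } });   // multi-dimensional input is flattened for display: "[1, 2, 3, 4]"
    var shape = new TensorShape(Unknown, 28, 28, 1);   // Unknown (-1) marks the batch dimension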
diff --git a/src/TensorFlowNET.Core/Operations/Activation/gen_nn_ops.activations.cs b/src/TensorFlowNET.Core/Operations/Activation/gen_nn_ops.activations.cs
index 80e1c305..788adda4 100644
--- a/src/TensorFlowNET.Core/Operations/Activation/gen_nn_ops.activations.cs
+++ b/src/TensorFlowNET.Core/Operations/Activation/gen_nn_ops.activations.cs
@@ -14,20 +14,192 @@
    limitations under the License.
 ******************************************************************************/
 
+using System;
+using static Tensorflow.Binding;
+
 namespace Tensorflow.Operations.Activation
 {
+    public class sigmoid : IActivation
+    {
+        public Tensor Activate(Tensor x, string name = null)
+        {
+            return tf.sigmoid(x);
+        }
+    }
+
+    public class tanh : IActivation
+    {
+        public Tensor Activate(Tensor x, string name = null)
+        {
+            return tf.tanh(x);
+        }
+    }
+
+    public class leakyrelu : IActivation
+    {
+        private readonly float _alpha;
+
+        public leakyrelu(float alpha = 0.3f)
+        {
+            _alpha = alpha;
+        }
+
+        public Tensor Activate(Tensor x, string name = null)
+        {
+            return nn_ops.leaky_relu(x, _alpha);
+        }
+    }
+
+    public class elu : IActivation
+    {
+        private readonly float _alpha;
+
+        public elu(float alpha = 0.1f)
+        {
+            _alpha = alpha;
+        }
+
+        public Tensor Activate(Tensor x, string name = null)
+        {
+            var res = gen_ops.elu(x);
+            if (Math.Abs(_alpha - 0.1f) < 0.00001f)
+            {
+                return res;
+            }
+
+            return array_ops.@where(x > 0, res, _alpha * res);
+        }
+    }
+
+    public class softmax : IActivation
+    {
+        private readonly int _axis;
+
+        /// <summary>Initializes a new instance of the <see cref="softmax"/> class.</summary>
+        public softmax(int axis = -1)
+        {
+            _axis = axis;
+        }
+
+        public Tensor Activate(Tensor x, string name = null)
+        {
+            return nn_ops.softmax(x, _axis);
+        }
+    }
+
+    public class softplus : IActivation
+    {
+        public Tensor Activate(Tensor x, string name = null)
+        {
+            return gen_ops.softplus(x);
+        }
+    }
+
+    public class softsign : IActivation
+    {
+        public Tensor Activate(Tensor x, string name = null)
+        {
+            return gen_ops.softsign(x);
+        }
+    }
+
+    public class linear : IActivation
+    {
+        public Tensor Activate(Tensor x, string name = null)
+        {
+            return x;
+        }
+    }
+
+    public class exponential : IActivation
+    {
+        public Tensor Activate(Tensor x, string name = null)
+        {
+            return tf.exp(x, name: name);
+        }
+    }
+
     public class relu : IActivation
     {
-        public Tensor Activate(Tensor features, string name = null)
+        private readonly float _threshold;
+        private readonly float _alpha;
+        private readonly float? _maxValue;
+
+        public relu(float threshold = 0f, float alpha = 0.2f, float? max_value = null)
+        {
+            _threshold = threshold;
+            _alpha = alpha;
+            _maxValue = max_value;
+        }
+
+        public Tensor Activate(Tensor x, string name = null)
         {
-            OpDefLibrary _op_def_lib = new OpDefLibrary();
+            // based on keras/backend.py relu
+            if (Math.Abs(_alpha) > 0.000001f)
+            {
+                if (!_maxValue.HasValue && Math.Abs(_threshold) < 0.0001)
+                {
+                    return nn_ops.leaky_relu(x, _alpha);
+                }
+            }
+
+            Tensor negative_part;
+            if (Math.Abs(_threshold) > 0.000001f)
+            {
+                negative_part = gen_ops.relu(-x + _threshold);
+            } else
+            {
+                negative_part = gen_ops.relu(-x);
+            }
+
+            if (Math.Abs(_threshold) > 0.000001f)
+            {
+                x = x * math_ops.cast(tf.greater(x, _threshold), TF_DataType.TF_FLOAT);
+            } else if (_maxValue.HasValue && Math.Abs(_maxValue.Value - 6f) < 0.0001f)
+            {
+                x = gen_ops.relu6(x);
+            } else
+            {
+                x = gen_ops.relu(x);
+            }
+
+            bool clip_max = _maxValue.HasValue;
+            if (clip_max)
+            {
+                Tensor maxval = constant_op.constant(_maxValue, x.dtype.as_base_dtype());
+                var zero = constant_op.constant(0.0f, x.dtype.as_base_dtype());
+                x = gen_ops.clip_by_value(x, zero, maxval);
+            }
 
-            var _op = _op_def_lib._apply_op_helper("Relu", name: name, args: new
-            {
-                features
-            });
+            if (Math.Abs(_alpha) > 0.00001)
+            {
+                var a = constant_op.constant(_alpha, x.dtype.as_base_dtype());
+                x -= a * negative_part;
+            }
 
-            return _op.outputs[0];
+            return x;
+        }
+    }
+
+    public class selu : IActivation
+    {
+        public Tensor Activate(Tensor x, string name = null)
+        {
+            const float alpha = 1.6732632423543772848170429916717f;
+            const float scale = 1.0507009873554804934193349852946f;
+            return scale * new elu(alpha).Activate(x, name);
+        }
+    }
+
+    public class hard_sigmoid : IActivation
+    {
+        public Tensor Activate(Tensor x, string name = null)
+        {
+            x = (0.2 * x) + 0.5;
+            var zero = tf.convert_to_tensor(0.0f, x.dtype.as_base_dtype());
+            var one = tf.convert_to_tensor(1.0f, x.dtype.as_base_dtype());
+            return tf.clip_by_value(x, zero, one);
         }
     }
-}
+}
\ No newline at end of file
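Usage sketch (not part of the diff): the activation classes can be handed to any API that accepts IActivation, for example the dense layer from the first file in this diff; the numeric behaviour follows the keras-backend-style formulas implemented above.

    var x = tf.placeholder(TF_DataType.TF_FLOAT, new TensorShape(Unknown, 784));
    var hidden = tf.layers.dense(x, 128, activation: new relu(alpha: 0f, max_value: 6f));  // plain relu capped at 6
    var logits = tf.layers.dense(hidden, 10, activation: new softmax());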
diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
index 1fc95927..853255cb 100644
--- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs
+++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
@@ -3,6 +3,7 @@ using System;
 using System.Diagnostics.CodeAnalysis;
 using System.Linq;
 using System.Runtime.CompilerServices;
+using NumSharp.Utilities;
 
 namespace Tensorflow
 {
@@ -32,7 +33,23 @@ namespace Tensorflow
         /// <summary>
         ///     Returns the size this shape represents.
        /// </summary>
-        public int size => shape.Size;
+        public int size
+        {
+            get
+            {
+                var dims = shape.Dimensions;
+                var computed = 1;
+                for (int i = 0; i < dims.Length; i++)
+                {
+                    var val = dims[i];
+                    if (val <= 0)
+                        continue;
+                    computed *= val;
+                }
+
+                return computed;
+            }
+        }
 
         public TensorShape(TensorShapeProto proto)
         {
@@ -59,12 +76,30 @@ namespace Tensorflow
             switch (dims.Length)
             {
                 case 0: shape = new Shape(new int[0]); break;
-                case 1: shape = Shape.Vector((int) dims[0]); break;
+                case 1: shape = Shape.Vector((int)dims[0]); break;
                 case 2: shape = Shape.Matrix(dims[0], dims[1]); break;
                 default: shape = new Shape(dims); break;
             }
         }
 
+        public TensorShape(int[][] dims)
+        {
+            if (dims.Length == 1)
+            {
+                switch (dims[0].Length)
+                {
+                    case 0: shape = new Shape(new int[0]); break;
+                    case 1: shape = Shape.Vector((int)dims[0][0]); break;
+                    case 2: shape = Shape.Matrix(dims[0][0], dims[0][1]); break;
+                    default: shape = new Shape(dims[0]); break;
+                }
+            }
+            else
+            {
+                throw new NotImplementedException("TensorShape int[][] dims");
+            }
+        }
+
         /// <summary>
         ///
         /// </summary>
@@ -188,6 +223,11 @@ namespace Tensorflow
         public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0);
         public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6);
-        
+
+        public static explicit operator (int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 7 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6]) : (0, 0, 0, 0, 0, 0, 0);
+        public static implicit operator TensorShape((int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7);
+
+        public static explicit operator (int, int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 8 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6], shape.dims[7]) : (0, 0, 0, 0, 0, 0, 0, 0);
+        public static implicit operator TensorShape((int, int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8);
     }
 }
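Sketch (not part of the diff): tuples containing Unknown convert implicitly to TensorShape, and the reworked size skips unknown (-1) dimensions instead of letting them turn the product negative.

    TensorShape shape = (Unknown, 28, 28);   // implicit tuple conversion, batch dimension unknown
    var n = shape.size;                      // 28 * 28 = 784; the -1 dimension is ignored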
diff --git a/test/TensorFlowNET.UnitTest/TensorShapeTest.cs b/test/TensorFlowNET.UnitTest/TensorShapeTest.cs
new file mode 100644
index 00000000..efa7def3
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/TensorShapeTest.cs
@@ -0,0 +1,60 @@
+using System;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using static Tensorflow.Binding;
+
+namespace TensorFlowNET.UnitTest
+{
+    [TestClass]
+    public class TensorShapeTest
+    {
+        [TestMethod]
+        public void Case1()
+        {
+            int a = 2;
+            int b = 3;
+            var dims = new[] { Unknown, a, b };
+            new TensorShape(dims).GetPrivate("shape").Should().BeShaped(-1, 2, 3);
+        }
+
+        [TestMethod]
+        public void Case2()
+        {
+            int a = 2;
+            int b = 3;
+            var dims = new[] { Unknown, a, b };
+            new TensorShape(new[] { dims }).GetPrivate("shape").Should().BeShaped(-1, 2, 3);
+        }
+
+        [TestMethod]
+        public void Case3()
+        {
+            int a = 2;
+            int b = Unknown;
+            var dims = new[] { Unknown, a, b };
+            new TensorShape(new[] { dims }).GetPrivate("shape").Should().BeShaped(-1, 2, -1);
+        }
+
+        [TestMethod]
+        public void Case4()
+        {
+            TensorShape shape = (Unknown, Unknown);
+            shape.GetPrivate("shape").Should().BeShaped(-1, -1);
+        }
+
+        [TestMethod]
+        public void Case5()
+        {
+            TensorShape shape = (1, Unknown, 3);
+            shape.GetPrivate("shape").Should().BeShaped(1, -1, 3);
+        }
+
+        [TestMethod]
+        public void Case6()
+        {
+            TensorShape shape = (Unknown, 1, 2, 3, Unknown);
+            shape.GetPrivate("shape").Should().BeShaped(-1, 1, 2, 3, -1);
+        }
+    }
+}
\ No newline at end of file
diff --git a/test/TensorFlowNET.UnitTest/layers_test/flatten.cs b/test/TensorFlowNET.UnitTest/layers_test/flatten.cs
index d533f128..fa8ec792 100644
--- a/test/TensorFlowNET.UnitTest/layers_test/flatten.cs
+++ b/test/TensorFlowNET.UnitTest/layers_test/flatten.cs
@@ -36,5 +36,23 @@ namespace TensorFlowNET.UnitTest.layers_test
             var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape());
             new Action(() => sess.run(tf.layers.flatten(input), (input, NDArray.Scalar(6)))).Should().Throw<ValueError>();
         }
+
+        [TestMethod]
+        public void Case4()
+        {
+            var sess = tf.Session().as_default();
+
+            var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, Unknown, 1, 2));
+            sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
+        }
+
+        [TestMethod]
+        public void Case5()
+        {
+            var sess = tf.Session().as_default();
+
+            var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(Unknown, 4, 3, 1, 2));
+            sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
+        }
     }
 }
\ No newline at end of file
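Closing usage sketch tying the pieces of this diff together (not part of the diff itself): dense now forwards trainable, so a frozen head with one of the new activation classes can be declared directly; the exact variable-collection behaviour is whatever the underlying Dense layer implements.

    var features = tf.placeholder(TF_DataType.TF_FLOAT, new TensorShape(Unknown, 128));
    var logits = tf.layers.dense(features, 10, activation: new softmax(), trainable: false);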