diff --git a/src/TensorFlowNET.Core/APIs/tf.linalg.cs b/src/TensorFlowNET.Core/APIs/tf.linalg.cs index 398fd508..bf8c358c 100644 --- a/src/TensorFlowNET.Core/APIs/tf.linalg.cs +++ b/src/TensorFlowNET.Core/APIs/tf.linalg.cs @@ -18,10 +18,33 @@ namespace Tensorflow { public partial class tensorflow { + public LinalgApi linalg { get; } = new LinalgApi(); + + public class LinalgApi + { + linalg_ops ops = new linalg_ops(); + + public Tensor eye(int num_rows, + int num_columns = -1, + TensorShape batch_shape = null, + TF_DataType dtype = TF_DataType.TF_FLOAT, + string name = null) + => ops.eye(num_rows, num_columns: num_columns, batch_shape: batch_shape, dtype: dtype, name: name); + + public Tensor diag(Tensor diagonal, string name = null) + => gen_array_ops.diag(diagonal, name: name); + + public Tensor matmul(Tensor a, Tensor b) + => math_ops.matmul(a, b); + + public Tensor batch_matmul(Tensor x, Tensor y) + => gen_math_ops.batch_mat_mul(x, y); + } + public Tensor diag(Tensor diagonal, string name = null) => gen_array_ops.diag(diagonal, name: name); - public Tensor matmul(Tensor a, Tensor b) + public Tensor matmul(Tensor a, Tensor b) => math_ops.matmul(a, b); public Tensor batch_matmul(Tensor x, Tensor y) diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index a4335046..ac101061 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -599,6 +599,46 @@ namespace Tensorflow public static Tensor invert_permutation(Tensor x, string name = null) => gen_array_ops.invert_permutation(x, name: name); + public static Tensor matrix_diag(Tensor diagonal, + string name = "diag", + int k = 0, + int num_rows = -1, + int num_cols = -1, + float padding_value = 0, + string align = "RIGHT_LEFT") + { + if (tf.context.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, + "MatrixDiagV3", name, + null, + diagonal, k, num_rows, num_cols, padding_value, + "align", align); + return results[0]; + } + + throw new NotImplementedException(""); + } + + public static Tensor matrix_set_diag(Tensor input, + Tensor diagonal, + string name = "set_diag", + int k = 0, + string align = "RIGHT_LEFT") + { + if (tf.context.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, + "MatrixSetDiagV3", name, + null, + input, diagonal, k, + "align", align); + return results[0]; + } + + throw new NotImplementedException(""); + } + /// /// Computes the shape of a broadcast given symbolic shapes. /// When shape_x and shape_y are Tensors representing shapes(i.e.the result of diff --git a/src/TensorFlowNET.Core/Operations/linalg_ops.cs b/src/TensorFlowNET.Core/Operations/linalg_ops.cs new file mode 100644 index 00000000..cbbe262a --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/linalg_ops.cs @@ -0,0 +1,43 @@ +using NumSharp; +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class linalg_ops + { + public Tensor eye(int num_rows, + int num_columns = -1, + TensorShape batch_shape = null, + TF_DataType dtype = TF_DataType.TF_FLOAT, + string name = null) + { + return tf_with(ops.name_scope(name, default_name: "eye", new { num_rows, num_columns, batch_shape }), scope => + { + if (num_columns == -1) + num_columns = num_rows; + + bool is_square = num_columns == num_rows; + var diag_size = Math.Min(num_rows, num_columns); + if (batch_shape == null) + batch_shape = new TensorShape(new int[0]); + var diag_shape = batch_shape.dims.concat(new[] { diag_size }); + + int[] shape = null; + if (!is_square) + shape = batch_shape.dims.concat(new[] { num_rows, num_columns }); + + var diag_ones = array_ops.ones(diag_shape, dtype: dtype); + if (is_square) + return array_ops.matrix_diag(diag_ones); + else + { + var zero_matrix = array_ops.zeros(shape, dtype: dtype); + return array_ops.matrix_set_diag(zero_matrix, diag_ones); + } + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs new file mode 100644 index 00000000..7f4fb27d --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs @@ -0,0 +1,40 @@ +using NumSharp; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public partial class TensorShape + { + public static implicit operator TensorShape(Shape shape) => new TensorShape((int[])shape.Dimensions.Clone()); + public static implicit operator Shape(TensorShape shape) => new Shape((int[])shape.dims.Clone()); + + public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes + public static implicit operator TensorShape(int[] dims) => dims == null ? null : new TensorShape(dims); + + public static explicit operator int(TensorShape shape) => shape.size; + public static implicit operator TensorShape(int dim) => new TensorShape(dim); + + public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0); + public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); + + public static explicit operator (int, int, int)(TensorShape shape) => shape.dims.Length == 3 ? (shape.dims[0], shape.dims[1], shape.dims[2]) : (0, 0, 0); + public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); + + public static explicit operator (int, int, int, int)(TensorShape shape) => shape.dims.Length == 4 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3]) : (0, 0, 0, 0); + public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); + + public static explicit operator (int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 5 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4]) : (0, 0, 0, 0, 0); + public static implicit operator TensorShape((int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5); + + public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0); + public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6); + + public static explicit operator (int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 7 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6]) : (0, 0, 0, 0, 0, 0, 0); + public static implicit operator TensorShape((int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7); + + public static explicit operator (int, int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 8 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6], shape.dims[7]) : (0, 0, 0, 0, 0, 0, 0, 0); + public static implicit operator TensorShape((int, int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8); + } +} diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.Equals.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.Equals.cs new file mode 100644 index 00000000..a8843e11 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.Equals.cs @@ -0,0 +1,32 @@ +using NumSharp; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow +{ + public partial class TensorShape + { + public override bool Equals(Object obj) + { + switch (obj) + { + case TensorShape shape1: + return Enumerable.SequenceEqual(shape1.dims, dims); + default: + return false; + } + } + + /*public static bool operator ==(TensorShape shape1, TensorShape shape2) + { + return false; + } + + public static bool operator !=(TensorShape shape1, TensorShape shape2) + { + return false; + }*/ + } +} diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 07215701..ad00665c 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -1,6 +1,7 @@ using NumSharp; using System; using System.Collections.Generic; +using System.ComponentModel; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; @@ -12,7 +13,7 @@ namespace Tensorflow /// Represents the shape of a `Tensor`. /// /// https://www.tensorflow.org/api_docs/python/tf/TensorShape - public class TensorShape + public partial class TensorShape { private readonly Shape shape; @@ -255,35 +256,5 @@ namespace Tensorflow { return shape.ToString(); } - - public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone()); - public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone()); - - public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes - public static implicit operator TensorShape(int[] dims) => dims == null ? null : new TensorShape(dims); - - public static explicit operator int(TensorShape shape) => shape.size; - public static implicit operator TensorShape(int dim) => new TensorShape(dim); - - public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0); - public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); - - public static explicit operator (int, int, int)(TensorShape shape) => shape.dims.Length == 3 ? (shape.dims[0], shape.dims[1], shape.dims[2]) : (0, 0, 0); - public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); - - public static explicit operator (int, int, int, int)(TensorShape shape) => shape.dims.Length == 4 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3]) : (0, 0, 0, 0); - public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); - - public static explicit operator (int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 5 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4]) : (0, 0, 0, 0, 0); - public static implicit operator TensorShape((int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5); - - public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0); - public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6); - - public static explicit operator (int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 7 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6]) : (0, 0, 0, 0, 0, 0, 0); - public static implicit operator TensorShape((int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7); - - public static explicit operator (int, int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 8 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6], shape.dims[7]) : (0, 0, 0, 0, 0, 0, 0, 0); - public static implicit operator TensorShape((int, int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8); } } diff --git a/test/TensorFlowNET.UnitTest/TF_API/LinalgTest.cs b/test/TensorFlowNET.UnitTest/TF_API/LinalgTest.cs new file mode 100644 index 00000000..54302993 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/TF_API/LinalgTest.cs @@ -0,0 +1,24 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.UnitTest.TF_API +{ + [TestClass] + public class LinalgTest + { + [TestMethod] + public void EyeTest() + { + var tensor = tf.linalg.eye(3); + + Assert.AreEqual((3, 3), tensor.TensorShape); + + Assert.AreEqual(0.0f, (float)tensor[2, 0]); + Assert.AreEqual(0.0f, (float)tensor[2, 1]); + Assert.AreEqual(1.0f, (float)tensor[2, 2]); + } + } +}