diff --git a/src/TensorFlowNET.Core/APIs/tf.linalg.cs b/src/TensorFlowNET.Core/APIs/tf.linalg.cs
index 398fd508..bf8c358c 100644
--- a/src/TensorFlowNET.Core/APIs/tf.linalg.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.linalg.cs
@@ -18,10 +18,33 @@ namespace Tensorflow
{
public partial class tensorflow
{
+ public LinalgApi linalg { get; } = new LinalgApi();
+
+ public class LinalgApi
+ {
+ linalg_ops ops = new linalg_ops();
+
+ public Tensor eye(int num_rows,
+ int num_columns = -1,
+ TensorShape batch_shape = null,
+ TF_DataType dtype = TF_DataType.TF_FLOAT,
+ string name = null)
+ => ops.eye(num_rows, num_columns: num_columns, batch_shape: batch_shape, dtype: dtype, name: name);
+
+ public Tensor diag(Tensor diagonal, string name = null)
+ => gen_array_ops.diag(diagonal, name: name);
+
+ public Tensor matmul(Tensor a, Tensor b)
+ => math_ops.matmul(a, b);
+
+ public Tensor batch_matmul(Tensor x, Tensor y)
+ => gen_math_ops.batch_mat_mul(x, y);
+ }
+
public Tensor diag(Tensor diagonal, string name = null)
=> gen_array_ops.diag(diagonal, name: name);
- public Tensor matmul(Tensor a, Tensor b)
+ public Tensor matmul(Tensor a, Tensor b)
=> math_ops.matmul(a, b);
public Tensor batch_matmul(Tensor x, Tensor y)
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs
index a4335046..ac101061 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.cs
@@ -599,6 +599,46 @@ namespace Tensorflow
public static Tensor invert_permutation(Tensor x, string name = null)
=> gen_array_ops.invert_permutation(x, name: name);
+ public static Tensor matrix_diag(Tensor diagonal,
+ string name = "diag",
+ int k = 0,
+ int num_rows = -1,
+ int num_cols = -1,
+ float padding_value = 0,
+ string align = "RIGHT_LEFT")
+ {
+ if (tf.context.executing_eagerly())
+ {
+ var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
+ "MatrixDiagV3", name,
+ null,
+ diagonal, k, num_rows, num_cols, padding_value,
+ "align", align);
+ return results[0];
+ }
+
+ throw new NotImplementedException("");
+ }
+
+ public static Tensor matrix_set_diag(Tensor input,
+ Tensor diagonal,
+ string name = "set_diag",
+ int k = 0,
+ string align = "RIGHT_LEFT")
+ {
+ if (tf.context.executing_eagerly())
+ {
+ var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
+ "MatrixSetDiagV3", name,
+ null,
+ input, diagonal, k,
+ "align", align);
+ return results[0];
+ }
+
+ throw new NotImplementedException("");
+ }
+
///
/// Computes the shape of a broadcast given symbolic shapes.
/// When shape_x and shape_y are Tensors representing shapes(i.e.the result of
diff --git a/src/TensorFlowNET.Core/Operations/linalg_ops.cs b/src/TensorFlowNET.Core/Operations/linalg_ops.cs
new file mode 100644
index 00000000..cbbe262a
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/linalg_ops.cs
@@ -0,0 +1,43 @@
+using NumSharp;
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+
+namespace Tensorflow
+{
+ public class linalg_ops
+ {
+ public Tensor eye(int num_rows,
+ int num_columns = -1,
+ TensorShape batch_shape = null,
+ TF_DataType dtype = TF_DataType.TF_FLOAT,
+ string name = null)
+ {
+ return tf_with(ops.name_scope(name, default_name: "eye", new { num_rows, num_columns, batch_shape }), scope =>
+ {
+ if (num_columns == -1)
+ num_columns = num_rows;
+
+ bool is_square = num_columns == num_rows;
+ var diag_size = Math.Min(num_rows, num_columns);
+ if (batch_shape == null)
+ batch_shape = new TensorShape(new int[0]);
+ var diag_shape = batch_shape.dims.concat(new[] { diag_size });
+
+ int[] shape = null;
+ if (!is_square)
+ shape = batch_shape.dims.concat(new[] { num_rows, num_columns });
+
+ var diag_ones = array_ops.ones(diag_shape, dtype: dtype);
+ if (is_square)
+ return array_ops.matrix_diag(diag_ones);
+ else
+ {
+ var zero_matrix = array_ops.zeros(shape, dtype: dtype);
+ return array_ops.matrix_set_diag(zero_matrix, diag_ones);
+ }
+ });
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs
new file mode 100644
index 00000000..7f4fb27d
--- /dev/null
+++ b/src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs
@@ -0,0 +1,40 @@
+using NumSharp;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public partial class TensorShape
+ {
+ public static implicit operator TensorShape(Shape shape) => new TensorShape((int[])shape.Dimensions.Clone());
+ public static implicit operator Shape(TensorShape shape) => new Shape((int[])shape.dims.Clone());
+
+ public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes
+ public static implicit operator TensorShape(int[] dims) => dims == null ? null : new TensorShape(dims);
+
+ public static explicit operator int(TensorShape shape) => shape.size;
+ public static implicit operator TensorShape(int dim) => new TensorShape(dim);
+
+ public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0);
+ public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2);
+
+ public static explicit operator (int, int, int)(TensorShape shape) => shape.dims.Length == 3 ? (shape.dims[0], shape.dims[1], shape.dims[2]) : (0, 0, 0);
+ public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3);
+
+ public static explicit operator (int, int, int, int)(TensorShape shape) => shape.dims.Length == 4 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3]) : (0, 0, 0, 0);
+ public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);
+
+ public static explicit operator (int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 5 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4]) : (0, 0, 0, 0, 0);
+ public static implicit operator TensorShape((int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5);
+
+ public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0);
+ public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6);
+
+ public static explicit operator (int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 7 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6]) : (0, 0, 0, 0, 0, 0, 0);
+ public static implicit operator TensorShape((int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7);
+
+ public static explicit operator (int, int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 8 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6], shape.dims[7]) : (0, 0, 0, 0, 0, 0, 0, 0);
+ public static implicit operator TensorShape((int, int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8);
+ }
+}
diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.Equals.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.Equals.cs
new file mode 100644
index 00000000..a8843e11
--- /dev/null
+++ b/src/TensorFlowNET.Core/Tensors/TensorShape.Equals.cs
@@ -0,0 +1,32 @@
+using NumSharp;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Tensorflow
+{
+ public partial class TensorShape
+ {
+ public override bool Equals(Object obj)
+ {
+ switch (obj)
+ {
+ case TensorShape shape1:
+ return Enumerable.SequenceEqual(shape1.dims, dims);
+ default:
+ return false;
+ }
+ }
+
+ /*public static bool operator ==(TensorShape shape1, TensorShape shape2)
+ {
+ return false;
+ }
+
+ public static bool operator !=(TensorShape shape1, TensorShape shape2)
+ {
+ return false;
+ }*/
+ }
+}
diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
index 07215701..ad00665c 100644
--- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs
+++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
@@ -1,6 +1,7 @@
using NumSharp;
using System;
using System.Collections.Generic;
+using System.ComponentModel;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
@@ -12,7 +13,7 @@ namespace Tensorflow
/// Represents the shape of a `Tensor`.
///
/// https://www.tensorflow.org/api_docs/python/tf/TensorShape
- public class TensorShape
+ public partial class TensorShape
{
private readonly Shape shape;
@@ -255,35 +256,5 @@ namespace Tensorflow
{
return shape.ToString();
}
-
- public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone());
- public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone());
-
- public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes
- public static implicit operator TensorShape(int[] dims) => dims == null ? null : new TensorShape(dims);
-
- public static explicit operator int(TensorShape shape) => shape.size;
- public static implicit operator TensorShape(int dim) => new TensorShape(dim);
-
- public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0);
- public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2);
-
- public static explicit operator (int, int, int)(TensorShape shape) => shape.dims.Length == 3 ? (shape.dims[0], shape.dims[1], shape.dims[2]) : (0, 0, 0);
- public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3);
-
- public static explicit operator (int, int, int, int)(TensorShape shape) => shape.dims.Length == 4 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3]) : (0, 0, 0, 0);
- public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);
-
- public static explicit operator (int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 5 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4]) : (0, 0, 0, 0, 0);
- public static implicit operator TensorShape((int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5);
-
- public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0);
- public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6);
-
- public static explicit operator (int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 7 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6]) : (0, 0, 0, 0, 0, 0, 0);
- public static implicit operator TensorShape((int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7);
-
- public static explicit operator (int, int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 8 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6], shape.dims[7]) : (0, 0, 0, 0, 0, 0, 0, 0);
- public static implicit operator TensorShape((int, int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8);
}
}
diff --git a/test/TensorFlowNET.UnitTest/TF_API/LinalgTest.cs b/test/TensorFlowNET.UnitTest/TF_API/LinalgTest.cs
new file mode 100644
index 00000000..54302993
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/TF_API/LinalgTest.cs
@@ -0,0 +1,24 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.UnitTest.TF_API
+{
+ [TestClass]
+ public class LinalgTest
+ {
+ [TestMethod]
+ public void EyeTest()
+ {
+ var tensor = tf.linalg.eye(3);
+
+ Assert.AreEqual((3, 3), tensor.TensorShape);
+
+ Assert.AreEqual(0.0f, (float)tensor[2, 0]);
+ Assert.AreEqual(0.0f, (float)tensor[2, 1]);
+ Assert.AreEqual(1.0f, (float)tensor[2, 2]);
+ }
+ }
+}