From 5821275145e5123d1acbc4094acd2baef06ac138 Mon Sep 17 00:00:00 2001 From: Superpiffer Date: Fri, 10 Feb 2023 16:21:26 +0100 Subject: [PATCH] Reimplemented NDArray == and != operators, handling null values. Added unit tests. --- .../NumPy/NDArray.Operators.cs | 22 ++++++++++--- .../NumPy/OperatorsTest.cs | 33 +++++++++++++++++++ 2 files changed, 51 insertions(+), 4 deletions(-) create mode 100644 test/TensorFlowNET.UnitTest/NumPy/OperatorsTest.cs diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs index 7168678a..ef3b76f7 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs @@ -25,10 +25,24 @@ namespace Tensorflow.NumPy [AutoNumPy] public static NDArray operator -(NDArray lhs) => new NDArray(gen_math_ops.neg(lhs)); [AutoNumPy] - public static NDArray operator ==(NDArray lhs, NDArray rhs) - => rhs is null ? Scalar(false) : new NDArray(math_ops.equal(lhs, rhs)); + public static NDArray operator ==(NDArray lhs, NDArray rhs) + { + if(ReferenceEquals(lhs, rhs)) + return Scalar(true); + if(lhs is null) + return Scalar(false); + if(rhs is null) + return Scalar(false); + return new NDArray(math_ops.equal(lhs, rhs)); + } [AutoNumPy] - public static NDArray operator !=(NDArray lhs, NDArray rhs) - => new NDArray(math_ops.not_equal(lhs, rhs)); + public static NDArray operator !=(NDArray lhs, NDArray rhs) + { + if(ReferenceEquals(lhs, rhs)) + return Scalar(false); + if(lhs is null || rhs is null) + return Scalar(true); + return new NDArray(math_ops.not_equal(lhs, rhs)); + } } } diff --git a/test/TensorFlowNET.UnitTest/NumPy/OperatorsTest.cs b/test/TensorFlowNET.UnitTest/NumPy/OperatorsTest.cs new file mode 100644 index 00000000..e4989a1d --- /dev/null +++ b/test/TensorFlowNET.UnitTest/NumPy/OperatorsTest.cs @@ -0,0 +1,33 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; + +namespace TensorFlowNET.UnitTest.NumPy +{ + [TestClass] + public class OperatorsTest + { + [TestMethod] + public void EqualToOperator() + { + NDArray n1 = null; + NDArray n2 = new NDArray(1); + + Assert.IsTrue(n1 == null); + Assert.IsFalse(n2 == null); + Assert.IsFalse(n1 == 1); + Assert.IsTrue(n2 == 1); + } + + [TestMethod] + public void NotEqualToOperator() + { + NDArray n1 = null; + NDArray n2 = new NDArray(1); + + Assert.IsFalse(n1 != null); + Assert.IsTrue(n2 != null); + Assert.IsTrue(n1 != 1); + Assert.IsFalse(n2 != 1); + } + } +}