diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs index ef3b76f7..1149b798 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs @@ -33,7 +33,16 @@ namespace Tensorflow.NumPy return Scalar(false); if(rhs is null) return Scalar(false); - return new NDArray(math_ops.equal(lhs, rhs)); + // TODO(Rinne): use np.allclose instead. + if (lhs.dtype.is_floating() || rhs.dtype.is_floating()) + { + var diff = tf.abs(lhs - rhs); + return new NDArray(gen_math_ops.less(diff, new NDArray(1e-5).astype(diff.dtype))); + } + else + { + return new NDArray(math_ops.equal(lhs, rhs)); + } } [AutoNumPy] public static NDArray operator !=(NDArray lhs, NDArray rhs) @@ -42,7 +51,15 @@ namespace Tensorflow.NumPy return Scalar(false); if(lhs is null || rhs is null) return Scalar(true); - return new NDArray(math_ops.not_equal(lhs, rhs)); + if (lhs.dtype.is_floating() || rhs.dtype.is_floating()) + { + var diff = tf.abs(lhs - rhs); + return new NDArray(gen_math_ops.greater_equal(diff, new NDArray(1e-5).astype(diff.dtype))); + } + else + { + return new NDArray(math_ops.not_equal(lhs, rhs)); + } } } } diff --git a/src/TensorflowNET.Hub/Tensorflow.Hub.csproj b/src/TensorflowNET.Hub/Tensorflow.Hub.csproj index e54ed3f8..fef8b34f 100644 --- a/src/TensorflowNET.Hub/Tensorflow.Hub.csproj +++ b/src/TensorflowNET.Hub/Tensorflow.Hub.csproj @@ -5,6 +5,23 @@ 10 enable 1.0.0 + TensorFlow.NET.Hub + Apache2.0 + true + true + Yaohui Liu, Haiping Chen + SciSharp STACK + true + Apache 2.0, Haiping Chen $([System.DateTime]::UtcNow.ToString(yyyy)) + https://github.com/SciSharp/TensorFlow.NET + git + http://scisharpstack.org + https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 + TensorFlow, SciSharp, Machine Learning, Deep Learning, TensorFlow Hub, TensorFlow.NET, TF.NET, AI + + Google's TensorFlow Hub full binding in .NET Standard. + A library for transfer learning with TensorFlow.NET. + diff --git a/test/TensorflowNET.Hub.Unittest/Tensorflow.Hub.Unittest.csproj b/test/TensorflowNET.Hub.Unittest/Tensorflow.Hub.Unittest.csproj index 6046dc53..35cb9f16 100644 --- a/test/TensorflowNET.Hub.Unittest/Tensorflow.Hub.Unittest.csproj +++ b/test/TensorflowNET.Hub.Unittest/Tensorflow.Hub.Unittest.csproj @@ -1,4 +1,4 @@ - + net6